-rwxr-xr-x 41356 saferewrite-20210915/analyze
#!/usr/bin/env python3
# Compiler command lines to try on every implementation; each entry is split
# on whitespace into the compiler executable followed by its flags.
compilerlist = (
  'clang -O1 -fwrapv -march=native',
  'gcc -O3 -march=native -mtune=native',
)
# Bit width of each API type name: 'int8'..'int64', then 'uint8'..'uint64'.
typebits = {
  '%sint%d' % (sign,width): width
  for sign in ('','u')
  for width in (8,16,32,64)
}
# Rust spelling of each API type name, e.g. 'int16' -> 'i16', 'uint64' -> 'u64'.
rusttype = {
  name: ('u' if name.startswith('u') else 'i') + str(width)
  for name,width in typebits.items()
}
import sys
import os
import shutil
import subprocess
import angr
import claripy
import multiprocessing
import random
import traceback
import functools
# ----- performance-related configuration...

numrandomtests = 16
# test 16 random inputs to each function
# (first two are all-0, all-1)
# of course more time on testing and fuzzing would be useful

maxsplit = 100
# max number of universes within an angr run
# (unroll_worker gives up once deadended+active states exceed this)
# XXX: allow per-primitive/per-implementation configuration

satvalidation1 = False
# True would mean: even if random tests fail, invoke general sat/unsat mechanism
# this is for extra validation of that mechanism

# Replace claripy's (private) simpleton simplifier table wholesale, keeping only
# the five simplifiers listed here.
# NOTE(review): relies on claripy internals (_simplifiers and these function
# names); recheck when upgrading claripy.
claripy.simplifications.simpleton._simplifiers = {
  'Reverse': claripy.simplifications.simpleton.bv_reverse_simplifier,
  'Extract': claripy.simplifications.simpleton.extract_simplifier,
  'Concat': claripy.simplifications.simpleton.concat_simplifier,
  'ZeroExt': claripy.simplifications.simpleton.zeroext_simplifier,
  'SignExt': claripy.simplifications.simpleton.signext_simplifier,
}
# this is somewhat ad-hoc monkey-patching
# underlying problem: flatten simplifiers can easily blow up

# values() and unroll_print() below walk expression DAGs recursively, so deep
# expressions need a large recursion limit.
sys.setrecursionlimit(1000000)

# ----- real work begins here

z3timeout = 2**32-1
# XXX: by default claripy satisfiable() has 5-minute timeout
# and indicates timeout by returning False
# current Z3 documentation says:
# timeout (in milliseconds) (UINT_MAX and 0 mean no timeout)
# so 2**32-1 (UINT_MAX) here means: never time out.
# Number of worker processes: the CPUs this process may run on (or all CPUs
# where sched_getaffinity is unavailable), overridable via $CORES, at least 1.
if hasattr(os,'sched_getaffinity'):
    os_cores = len(os.sched_getaffinity(0))
else:
    os_cores = multiprocessing.cpu_count()
os_cores = max(1,int(os.getenv('CORES',default=os_cores)))
import resource
def cputime():
    """Return user-mode CPU seconds used by this process plus its waited-for children."""
    total = 0.0
    for who in (resource.RUSAGE_SELF,resource.RUSAGE_CHILDREN):
        total += resource.getrusage(who).ru_utime
    return total
def notetime(builddir,what,time):
    """Record that step `what` took `time` CPU seconds in `builddir`.

    Echoes '<builddir> seconds <what> <time>' to stdout and appends
    '<what> <time>' to <builddir>/analysis/seconds (time with 6 decimals).
    """
    record = '%s %.6f' % (what,time)
    print('%s seconds %s' % (builddir,record))
    sys.stdout.flush()
    with open('%s/analysis/seconds' % builddir,'a') as f:
        f.write(record + '\n')
def note(builddir,conclusion,contents=None):
    """Record an analysis conclusion for a build directory.

    Prints '<builddir> <conclusion>' to stdout and creates (truncating)
    <builddir>/analysis/<conclusion>, containing str(contents) when contents
    is given and empty otherwise.
    """
    print('%s %s' % (builddir,conclusion))
    sys.stdout.flush()
    with open('%s/analysis/%s' % (builddir,conclusion),'w') as f:
        # identity test: contents may be an object with overloaded comparisons
        # (e.g. symbolic expressions), for which `!= None` is not a plain bool
        if contents is not None:
            f.write(str(contents))
# startdir is later embedded in linker command lines that are split on
# whitespace, so only allow characters that survive that unharmed.
startdir = os.getcwd()
assert set(startdir) <= set('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_./')

# start every run from an empty build tree
shutil.rmtree('build',ignore_errors=True)
os.makedirs('build')

# primitives: (o,p) pairs for each src/<p> directory that has an api file and
# is not marked sticky (mode bit 0o1000); sticky directories are skipped.
primitives = []
for o in ('src',):
    o = o.strip()
    if not o or not os.path.isdir(o):
        continue
    if os.stat(o).st_mode & 0o1000:
        print('%s sticky, skipping' % o)
        sys.stdout.flush()
        continue
    for p in sorted(os.listdir(o)):
        primdir = '%s/%s' % (o,p)
        if not os.path.isdir(primdir):
            continue
        if os.stat(primdir).st_mode & 0o1000:
            print('%s sticky, skipping' % primdir)
            sys.stdout.flush()
            continue
        if not os.path.exists('%s/api' % primdir):
            print('%s/api nonexistent, skipping' % primdir)
            sys.stdout.flush()
            continue
        primitives.append((o,p))
# Parse each primitive's api file into op_api[o,p], a 10-tuple:
#   inputs, outputs        lists of (csymbol, symboltype, entries), in file order
#   funargs, funargtypes   C call arguments and their C parameter types
#   crustargs, crustargtypes, rustargs
#                          parameter names/types for the generated Rust wrapper
#   funname                function named by the 'call' line
#   funret, funrettype     C lvalue prefix that receives the return value and the
#                          C return type (None and 'void' if there is no 'return' line)
# Each api line is whitespace-separated; recognized forms:
#   call NAME
#   return TYPE SYMBOL
#   in|out|inout TYPE SYMBOL [ENTRIES]   (ENTRIES present => passed by pointer)
# TYPE must be a key of typebits; SYMBOL must be lowercase ASCII letters.
# Lines starting with any other word are ignored.
op_api = {}
for o,p in primitives:
    inputs = []
    outputs = []
    funargs = []
    funargtypes = []
    rustargs = []
    rustargtypes = []
    crustargs = []
    crustargtypes = []
    funname = None
    funret = None
    funrettype = 'void'
    with open('%s/%s/api' % (o,p)) as f:
        for line in f:
            line = line.split()
            if len(line) == 0: continue
            if line[0] == 'call':
                funname = line[1]
                assert all(x in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_' for x in funname)
            if line[0] == 'return':
                # the return value is treated as a one-entry output whose
                # harness buffer alloc_<symbol>[0] receives the call's result
                symboltype = line[1]
                csymbol = line[2]
                assert all(c in 'abcdefghijklmnopqrstuvwxyz' for c in csymbol)
                entries = 1
                outputs += [(csymbol,symboltype,entries)]
                funret = 'alloc_%s'%csymbol
                funrettype = 'uint%d_t'%typebits[symboltype]
            if line[0] in ('in','out','inout'):
                symboltype = line[1]
                csymbol = line[2]
                assert all(c in 'abcdefghijklmnopqrstuvwxyz' for c in csymbol)
                if len(line) == 3:
                    pointer = False
                    entries = 1
                else:
                    pointer = True
                    entries = int(line[3])
                if line[0] in ('in','inout'):
                    inputs += [(csymbol,symboltype,entries)]
                if line[0] in ('out','inout'):
                    outputs += [(csymbol,symboltype,entries)]
                # the four arg lists below are appended in lockstep so that
                # position k describes the same argument in each of them
                if pointer:
                    funargs += ['alloc_%s'%csymbol]
                    funargtypes += ['uint%d_t *' % typebits[symboltype]]
                    crustargs += ['c_%s'%csymbol]
                    rustargs += ['rust_%s'%csymbol]
                    if line[0] == 'in':
                        crustargtypes += ['*const %s' % rusttype[symboltype]]
                    else:
                        crustargtypes += ['*mut %s' % rusttype[symboltype]]
                else:
                    # scalar: C passes the value by dereferencing its one-entry buffer
                    funargs += ['*alloc_%s'%csymbol]
                    funargtypes += ['uint%d_t' % typebits[symboltype]]
                    crustargs += ['c_%s'%csymbol]
                    rustargs += ['rust_%s'%csymbol]
                    # NOTE(review): 'const i32' etc. does not look like valid Rust
                    # parameter-type syntax, so scalar args probably do not work for
                    # Rust implementations yet — confirm.
                    crustargtypes += ['const %s' % rusttype[symboltype]]
                    # XXX: support constant inputs
    op_api[o,p] = inputs,outputs,funargs,funargtypes,crustargs,crustargtypes,rustargs,funname,funret,funrettype
def input_example_str(inputs,x):
    """Format input vector x as lines 'in_<symbol>_<index> = <value>'.

    x must hold exactly one integer per input entry, in API order
    (IndexError if it is too short, AssertionError if it is too long).
    """
    names = ['in_%s_%d' % (csymbol,e)
             for csymbol,symboltype,entries in inputs
             for e in range(entries)]
    picked = [x[k] for k in range(len(names))]
    assert len(names) == len(x)
    return ''.join('%s = %d\n' % pair for pair in zip(names,picked))
def output_example_str(outputs,y):
    """Format output vector y as lines 'out_<symbol>_<index> = <value>'.

    y must hold exactly one integer per output entry, in API order
    (IndexError if it is too short, AssertionError if it is too long).
    """
    names = ['out_%s_%d' % (csymbol,e)
             for csymbol,symboltype,entries in outputs
             for e in range(entries)]
    picked = [y[k] for k in range(len(names))]
    assert len(names) == len(y)
    return ''.join('%s = %d\n' % pair for pair in zip(names,picked))
# File names the harness itself creates inside each build directory; an
# implementation shipping one of them would collide with generated files.
reservedfilenames = (
  'library.so.1',
  'analysis',
  'analysis-execute',
  'analysis-execute.c',
  'analysis-valgrind',
  'analysis-valgrind.c',
  'analysis-angr',
  'analysis-angr.c',
)

# opimplementations[o,p]: names of the accepted implementation directories
# src/<p>/<i>. Sticky directories are skipped, and so is any implementation
# containing a reserved file name or a name with characters outside
# [A-Za-z0-9._-] (file names end up in whitespace-split command lines).
opimplementations = {}
for o,p in primitives:
    opimplementations[o,p] = []
    for i in sorted(os.listdir('%s/%s' % (o,p))):
        implementationdir = '%s/%s/%s' % (o,p,i)
        if not os.path.isdir(implementationdir): continue
        if os.stat(implementationdir).st_mode & 0o1000 == 0o1000:
            print('%s/%s/%s sticky, skipping' % (o,p,i))
            sys.stdout.flush()
            continue
        # Report every problematic file, then reject the whole implementation
        # if there was any. (Previously the `continue` only advanced the file
        # loop, so bad implementations were accepted anyway.)
        ok = True
        for f in sorted(os.listdir(implementationdir)):
            if f in reservedfilenames:
                print('%s/%s/%s/%s reserved filename' % (o,p,i,f))
                ok = False
            if any(fi not in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.' for fi in f):
                print('%s/%s/%s/%s prohibited character' % (o,p,i,f))
                ok = False
        sys.stdout.flush()
        if not ok: continue
        opimplementations[o,p] += [i]
        # parent of the per-compiler build directories created by compile()
        os.makedirs('build/%s/%s' % (p,i),exist_ok=True)
def compile(o,p,i,compiler):
    """Build implementation i of primitive (o,p) with one compiler.

    Copies src/<p>/<i> to build/<p>/<i>/<compilerword>, writes crypto_intN.h /
    crypto_uintN.h headers and three C harnesses (analysis-execute.c,
    analysis-valgrind.c, analysis-angr.c) around the API function, then
    produces executables analysis-execute, analysis-valgrind and analysis-angr:
    via cargo when the implementation has a Cargo.toml, otherwise by compiling
    its *.c/*.s/*.S files with `compiler` and linking with gcc.

    Returns (o,p,i,compiler,ok); when ok is False the reason has been
    recorded with note() in the build directory's analysis/ subdirectory.
    """
    compilerword = compiler.replace(' ','_').replace('=','_')
    implementationdir = '%s/%s/%s' % (o,p,i)
    builddir = 'build/%s/%s/%s' % (p,i,compilerword)
    inputs,outputs,funargs,funargtypes,crustargs,crustargtypes,rustargs,funname,funret,funrettype = op_api[o,p]

    files = sorted(os.listdir(implementationdir))
    rust = 'Cargo.toml' in files
    rustlocked = 'Cargo.lock' in files
    cfiles = [x for x in files if x.endswith('.c')]
    sfiles = [x for x in files if x.endswith('.s') or x.endswith('.S')]
    files = cfiles + sfiles

    shutil.copytree(implementationdir,builddir)
    os.makedirs('%s/analysis' % builddir)

    # crypto_intN / crypto_uintN type names map onto <inttypes.h> types
    for bits in 8,16,32,64:
        with open('%s/crypto_int%d.h' % (builddir,bits),'w') as f:
            f.write('#include <inttypes.h>\n')
            f.write('#define crypto_int%d int%d_t' % (bits,bits))
        with open('%s/crypto_uint%d.h' % (builddir,bits),'w') as f:
            f.write('#include <inttypes.h>\n')
            f.write('#define crypto_uint%d uint%d_t' % (bits,bits))

    # One C harness per analysis. Each declares static_<symbol> arrays and
    # calls the API function on malloc'd alloc_<symbol> buffers:
    #   execute  - reads inputs from stdin (scanf), prints outputs (printf)
    #   angr     - copies static_ inputs into the buffers and outputs back out
    #              (unroll_worker makes the static_ inputs symbolic)
    #   valgrind - calls the function without initializing the buffers
    for analysis in 'execute','valgrind','angr':
        with open('%s/analysis-%s.c' % (builddir,analysis),'w') as f:
            f.write('#include <stdio.h>\n')
            f.write('#include <stdlib.h>\n')
            f.write('#include <string.h>\n')
            f.write('#include <inttypes.h>\n')
            f.write('\n')
            # function declaration
            f.write('extern ')
            if funrettype is not None:
                f.write('%s ' % funrettype)
            f.write('%s(%s);\n' % (funname,','.join(funargtypes)))
            f.write('\n')
            for csymbol,symboltype,entries in inputs:
                f.write('uint%d_t static_%s[%d];\n' % (typebits[symboltype],csymbol,entries))
            for csymbol,symboltype,entries in outputs:
                if (csymbol,symboltype,entries) not in inputs:
                    f.write('uint%d_t static_%s[%d];\n' % (typebits[symboltype],csymbol,entries))
            f.write('\n')
            f.write('int main(int argc,char **argv)\n')
            f.write('{\n')
            # byte counts are exact integers: use integer division
            for csymbol,symboltype,entries in inputs:
                f.write(' uint%d_t *alloc_%s = malloc(%d);\n' % (typebits[symboltype],csymbol,entries*typebits[symboltype]//8))
            for csymbol,symboltype,entries in outputs:
                if (csymbol,symboltype,entries) not in inputs:
                    f.write(' uint%d_t *alloc_%s = malloc(%d);\n' % (typebits[symboltype],csymbol,entries*typebits[symboltype]//8))
            f.write('\n')
            # XXX: resource limits
            if analysis == 'execute':
                for csymbol,symboltype,entries in inputs:
                    f.write(' for (long long i = 0;i < %d;++i) {\n' % entries)
                    f.write(' unsigned long long x;\n')
                    f.write(' if (scanf("%llu",&x) != 1) abort();\n')
                    f.write(' static_%s[i] = x;\n' % csymbol)
                    f.write(' }\n')
                f.write('\n')
            if analysis in ('execute','angr'):
                for csymbol,symboltype,entries in inputs:
                    f.write(' for (long long i = 0;i < %d;++i)\n' % entries)
                    f.write(' alloc_%s[i] = static_%s[i];\n' % (csymbol,csymbol))
                f.write('\n')
            f.write(' ')
            if funret is not None:
                f.write('%s[0] = ' % funret)
            f.write('%s(%s);\n' % (funname,','.join(funargs)))
            f.write('\n')
            if analysis in ('execute','angr'):
                for csymbol,symboltype,entries in outputs:
                    f.write(' for (long long i = 0;i < %d;++i)\n' % entries)
                    f.write(' static_%s[i] = alloc_%s[i];\n' % (csymbol,csymbol))
                f.write('\n')
            if analysis == 'execute':
                for csymbol,symboltype,entries in outputs:
                    f.write(' for (long long i = 0;i < %d;++i) {\n' % entries)
                    f.write(' unsigned long long x = static_%s[i];\n' % csymbol)
                    f.write(' printf("%llu\\n",x);\n')
                    f.write(' }\n')
                f.write(' fflush(stdout);\n')
                f.write('\n')
            f.write(' return 0;\n')
            f.write('}\n')

    # ----- rust
    if rust:
        os.makedirs('%s/src/bin'%builddir,exist_ok=True)
        with open('%s/Cargo.toml'%builddir,'a') as f:
            f.write("""\
[build-dependencies]
cc = { git = "https://github.com/alexcrichton/cc-rs", version = "1.0.67", rev = "3283434ee41f2a1be6c78fe6d6a3ec8eb92b1833" }
[[bin]]
name = "analysis-execute"
[[bin]]
name = "analysis-valgrind"
[[bin]]
name = "analysis-angr"
[profile.release]
panic = "abort"
""")
        with open('%s/build.rs'%builddir,'w') as f:
            f.write("""\
use std::env;
fn main() {
let out_dir = env::var("OUT_DIR").unwrap();
println!("cargo:rustc-link-search=native={}",out_dir);
cc::Build::new()
.cargo_metadata(false)
.file("analysis-execute.c")
.compile("analysis-execute");
println!("cargo:rerun-if-changed=analysis-execute.c");
cc::Build::new()
.cargo_metadata(false)
.file("analysis-valgrind.c")
.compile("analysis-valgrind");
println!("cargo:rerun-if-changed=analysis-valgrind.c");
cc::Build::new()
.cargo_metadata(false)
.file("analysis-angr.c")
.compile("analysis-angr");
println!("cargo:rerun-if-changed=analysis-angr.c");
}
""")
        # Rust bin per analysis: links the C harness (whose main() runs) and
        # exports funname with the C ABI, forwarding to the crate's funname.
        for analysis in 'execute','valgrind','angr':
            with open('%s/src/bin/analysis-%s.rs'%(builddir,analysis),'w') as f:
                f.write('#![no_std]\n')
                f.write('#![no_main]\n')
                f.write('extern crate %s;\n'%funname)
                f.write('\n')
                f.write('use core::convert::TryInto;\n')
                f.write('use core::slice;\n')
                f.write('\n')
                f.write('#[link(name="analysis-%s")]\n'%analysis)
                f.write('extern {\n')
                f.write('#[allow(dead_code)]\n')
                f.write(' fn main() -> i32;\n')
                f.write('}\n')
                f.write('\n')
                f.write('#[no_mangle]\n')
                f.write('pub unsafe extern "C"\n')
                f.write('fn %s(\n'%funname)
                for arg,argtype in zip(crustargs,crustargtypes):
                    f.write(' %s:%s,\n' % (arg,argtype))
                f.write(') -> i32 {\n')
                for csymbol,symboltype,entries in outputs:
                    f.write(' let rust_%s = slice::from_raw_parts_mut(c_%s,%d);\n' % (csymbol,csymbol,entries))
                    f.write(' let rust_%s:&mut[%s;%d] = rust_%s.try_into().unwrap();\n' % (csymbol,rusttype[symboltype],entries,csymbol))
                for csymbol,symboltype,entries in inputs:
                    if (csymbol,symboltype,entries) not in outputs:
                        f.write(' let rust_%s = slice::from_raw_parts(c_%s,%d);\n' % (csymbol,csymbol,entries))
                f.write(' %s::%s(%s);\n' % (funname,funname,','.join(rustargs)))
                # XXX: return value
                f.write(' 0\n')
                f.write('}\n')
        cargotime = -cputime()
        # Copy the environment instead of mutating os.environ: this runs in a
        # long-lived pool worker, and CC/CFLAGS/RUSTFLAGS must not leak into
        # later jobs or other child processes of this worker.
        env = dict(os.environ)
        env['CC'] = compiler.split()[0]
        env['CFLAGS'] = ' '.join(compiler.split()[1:])
        env['RUSTFLAGS'] = '-C target_cpu=native -C link-args=-no-pie' # XXX: let user control this?
        if rustlocked:
            command = 'cargo build -q --locked --release --bin analysis-execute --bin analysis-valgrind --bin analysis-angr'
        else:
            command = 'cargo build -q --release --bin analysis-execute --bin analysis-valgrind --bin analysis-angr'
        try:
            proc = subprocess.Popen(command.split(),cwd=builddir,env=env,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True)
            out,err = proc.communicate()
        except OSError:
            note(builddir,'warning-cargofailed',traceback.format_exc())
            return o,p,i,compiler,False
        assert not err # stderr is merged into stdout
        if out != '':
            note(builddir,'warning-cargooutput',out)
        if proc.returncode:
            note(builddir,'warning-cargofailed','exit code %s' % proc.returncode)
            return o,p,i,compiler,False
        for analysis in 'execute','valgrind','angr':
            shutil.copy('%s/target/release/analysis-%s'%(builddir,analysis),'%s/analysis-%s'%(builddir,analysis))
        cargotime += cputime()
        notetime(builddir,'cargo',cargotime)
        return o,p,i,compiler,True

    # ----- compile
    compiletime = -cputime()
    objfiles = []
    for f in files+['analysis-execute.c','analysis-valgrind.c','analysis-angr.c']:
        command = '%s -Wall -fPIC -DCRYPTO_NAMESPACE(x)=x -c %s' % (compiler,f)
        try:
            proc = subprocess.Popen(command.split(),cwd=builddir,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True)
            out,err = proc.communicate()
        except OSError:
            note(builddir,'warning-compilefailed-%s'%f,traceback.format_exc())
            return o,p,i,compiler,False
        assert not err
        if out != '':
            note(builddir,'warning-compileoutput-%s'%f,out)
        if proc.returncode:
            note(builddir,'warning-compilefailed-%s'%f,'exit code %s' % proc.returncode)
            return o,p,i,compiler,False
        if f in files:
            # implementation objects (not harness objects) go into every executable
            objfiles += ['.'.join(f.split('.')[:-1]+['o'])]
    compiletime += cputime()
    notetime(builddir,'compile',compiletime)

    # ----- link into executable
    linktime = -cputime()
    for analysis in 'execute','valgrind','angr':
        static = True
        if static:
            command = 'gcc -no-pie -o analysis-%s analysis-%s.o' % (analysis,analysis)
            command = command.split()
            command += objfiles
        else:
            # build the implementation as a shared library, then link the
            # harness against it
            command = 'gcc -shared -Wl,-soname,library.so.1 -o library.so.1'
            command = command.split()
            command += objfiles
            try:
                proc = subprocess.Popen(command,cwd=builddir,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True)
                out,err = proc.communicate()
            except OSError:
                note(builddir,'warning-linkfailed-%s'%analysis,traceback.format_exc())
                return o,p,i,compiler,False
            if out != '':
                note(builddir,'warning-linkoutput-%s'%analysis,out)
            assert not err
            if proc.returncode:
                note(builddir,'warning-linkfailed-%s'%analysis,'exit code %s' % proc.returncode)
                return o,p,i,compiler,False
            shutil.copy('%s/library.so.1' % builddir,'%s/library.so' % builddir)
            command = 'gcc -no-pie -o analysis-%s analysis-%s.o -Wl,-rpath=%s/%s -L. -lrary' % (analysis,analysis,startdir,builddir)
            command = command.split()
        try:
            proc = subprocess.Popen(command,cwd=builddir,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True)
            out,err = proc.communicate()
        except OSError:
            note(builddir,'warning-linkfailed-%s'%analysis,traceback.format_exc())
            return o,p,i,compiler,False
        if out != '':
            note(builddir,'warning-linkoutput-%s'%analysis,out)
        assert not err
        if proc.returncode:
            note(builddir,'warning-linkfailed-%s'%analysis,'exit code %s' % proc.returncode)
            return o,p,i,compiler,False
    linktime += cputime()
    notetime(builddir,'link',linktime)
    return o,p,i,compiler,True
def wanttocompile():
    """Yield one (o,p,i,compiler) job per implementation and compiler."""
    for o,p in primitives:
        yield from ((o,p,i,compiler)
                    for i in opimplementations[o,p]
                    for compiler in compilerlist)
# op_compiled[o,p]: (implementation, compiler) pairs whose build succeeded.
op_compiled = {(o,p): [] for o,p in primitives}
with multiprocessing.Pool(os_cores) as pool:
    compiledjobs = pool.starmap(compile,wanttocompile())
for o,p,i,compiler,ok in compiledjobs:
    if ok:
        op_compiled[o,p].append((i,compiler))
print('----- execute')
# op_x[o,p]: numrandomtests input vectors, each listing every input entry in
# API order. Vector 0 is all zeros, vector 1 all maximal values, the rest are
# uniformly random.
op_x = {}
for o,p in primitives:
    inputs = op_api[o,p][0]
    vectors = []
    for execution in range(numrandomtests):
        x = []
        for csymbol,symboltype,entries in inputs:
            bound = 2**typebits[symboltype]
            for e in range(entries):
                x.append(0 if execution == 0 else bound-1 if execution == 1 else random.randrange(bound))
        vectors.append(x)
    op_x[o,p] = vectors
def _parse_execute_output(outputs,ystr):
    """Parse analysis-execute's stdout into a list of output integers.

    Expects one decimal integer per line, one line per output entry in API
    order, each in [0, 2**bits) for its type. Returns None otherwise.
    """
    try:
        y = [int(s) for s in ystr.splitlines()]
    except ValueError:
        return None
    bounds = [2**typebits[symboltype]
              for csymbol,symboltype,entries in outputs
              for e in range(entries)]
    if len(y) != len(bounds):
        return None
    if not all(0 <= yi < bound for yi,bound in zip(y,bounds)):
        return None
    return y

def execute(o,p,i,compiler):
    """Run one build's analysis-execute on every input vector in op_x[o,p].

    Returns (o,p,i,compiler,results), where results is a list with one output
    vector per input vector, or (o,p,i,compiler,False) if the binary could not
    be run, exited non-zero, or printed malformed output (recorded via note()).
    """
    compilerword = compiler.replace(' ','_').replace('=','_')
    builddir = 'build/%s/%s/%s' % (p,i,compilerword)
    inputs,outputs,funargs,funargtypes,crustargs,crustargtypes,rustargs,funname,funret,funrettype = op_api[o,p]
    executetime = -cputime()
    results = []
    command = ['./analysis-execute']
    for x in op_x[o,p]:
        xstr = ''.join('%d\n'%r for r in x)
        try:
            proc = subprocess.Popen(command,cwd=builddir,stdin=subprocess.PIPE,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True)
            ystr,err = proc.communicate(input=xstr)
        except OSError:
            note(builddir,'warning-executeerror',xstr)
            return o,p,i,compiler,False
        if proc.returncode != 0:
            note(builddir,'warning-executefailed',xstr+'exit code %s' % proc.returncode)
            return o,p,i,compiler,False
        y = _parse_execute_output(outputs,ystr)
        if y is None:
            # report the raw text: there is no well-formed y to format
            note(builddir,'warning-executebadformat',input_example_str(inputs,x)+'# unparseable output:\n'+ystr)
            return o,p,i,compiler,False
        results += [y]
    executetime += cputime()
    notetime(builddir,'execute',executetime)
    return o,p,i,compiler,results
def wanttoexecute():
    """Yield (o,p,i,compiler) for every build that compiled successfully."""
    for o,p in primitives:
        yield from ((o,p,i,compiler) for i,compiler in op_compiled[o,p])
# opic_y[o,p,i,compiler]: CPU outputs of that build for each vector in op_x[o,p];
# builds whose execution failed are absent.
opic_y = {}
with multiprocessing.Pool(os_cores) as pool:
    executedjobs = pool.starmap(execute,wanttoexecute())
for o,p,i,compiler,results in executedjobs:
    if results is not False:
        opic_y[o,p,i,compiler] = results
print('----- valgrind (can take some time)')
def valgrind(o,p,i,compiler):
    """Run one build's analysis-valgrind under valgrind and note the verdict.

    Exit code 99 (valgrind's reported-error exit code here) or a
    'client request' message means unsafe-valgrindfailure; any other non-zero
    exit or failure to start means warning-valgrinderror.
    """
    compilerword = compiler.replace(' ','_').replace('=','_')
    builddir = 'build/%s/%s/%s' % (p,i,compilerword)
    valgrindtime = -cputime()
    command = ['valgrind','-q','--error-exitcode=99','./analysis-valgrind']
    verdict = None
    try:
        proc = subprocess.Popen(command,cwd=builddir,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True)
        out,err = proc.communicate()
    except OSError:
        verdict = 'warning-valgrinderror'
    else:
        assert not err
        rc = proc.returncode
        if rc == 99 or (rc == 0 and 'client request' in out):
            verdict = 'unsafe-valgrindfailure'
        elif rc != 0:
            verdict = 'warning-valgrinderror'
    if verdict is not None:
        note(builddir,verdict)
    valgrindtime += cputime()
    notetime(builddir,'valgrind',valgrindtime)
def wanttovalgrind():
    """Yield (o,p,i,compiler) for every build that compiled successfully."""
    for o,p in primitives:
        yield from ((o,p,i,compiler) for i,compiler in op_compiled[o,p])
# valgrind() reports through note(); its return values are not needed
with multiprocessing.Pool(os_cores) as pool:
    pool.starmap(valgrind,wanttovalgrind())
print('----- unroll (can take tons of time)')
# XXX: could do this in parallel with valgrind
# XXX: unrolled can be huge; pass through disk instead of RAM
def values(terms,replacements):
    """Concretely evaluate claripy expression DAGs without calling Z3.

    terms: iterable of expressions (objects with .op, .args, .cache_key, and
      .size() for bit-vector leaves).
    replacements: dict mapping a BVS's cache_key to a concrete value object;
      only its .args[0] (the integer) is read.

    Returns (V,W):
      V: dict mapping the cache_key of every term and subterm to a pair (b,v)
         where v is a b-bit unsigned integer (booleans/comparisons use b=1);
         or None if some term uses a variable missing from replacements or an
         operation not handled here (the caller then falls back to Z3).
      W: set of tags ('mul', 'div') for multiplication/division operations
         seen; the caller reports each as a warning.
    Raises AssertionError on terms whose shape is not as expected.
    """
    V = {}
    W = set() # warnings
    def evaluate(t):
        # fill V for t and its subterms; False if t cannot be evaluated
        if t.cache_key in V:
            return True
        if t.op == 'BoolV':
            V[t.cache_key] = 1,t.args[0]
            return True
        if t.op == 'BVV':
            V[t.cache_key] = t.size(),t.args[0]
            return True
        if t.op == 'BVS':
            if t.cache_key not in replacements: return False
            V[t.cache_key] = t.size(),replacements[t.cache_key].args[0]
            return True
        if t.op == 'Extract':
            assert len(t.args) == 3
            top = t.args[0]
            bot = t.args[1]
            if not evaluate(t.args[2]): return False
            x0 = V[t.args[2].cache_key]
            assert x0[0] > top
            assert top >= bot
            assert bot >= 0
            # keep bits top..bot inclusive
            V[t.cache_key] = top+1-bot,((x0[1] & ((2<<top)-1)) >> bot)
            return True
        if t.op in ('SignExt','ZeroExt'):
            assert len(t.args) == 2
            if not evaluate(t.args[1]): return False
            x0bits,x0 = V[t.args[1].cache_key]
            extend = t.args[0]
            assert extend >= 0
            if t.op == 'SignExt':
                if x0 >= (1<<(x0bits-1)):
                    # negative: re-encode in x0bits+extend bits
                    x0 -= 1<<x0bits
                    x0 += 1<<(x0bits+extend)
            V[t.cache_key] = x0bits+extend,x0
            return True
        for a in t.args:
            if not evaluate(a): return False
        x = [V[a.cache_key] for a in t.args]
        if t.op == 'Concat':
            # first argument supplies the most significant bits
            y = 0
            ybits = 0
            for xbitsi,xi in x:
                y <<= xbitsi
                y += xi
                ybits += xbitsi
            V[t.cache_key] = ybits,y
            return True
        if t.op in ('__eq__','__ne__'):
            assert len(x) == 2
            assert x[0][0] == x[1][0]
            if t.op == '__eq__': V[t.cache_key] = 1,(x[0][1]==x[1][1])
            elif t.op == '__ne__': V[t.cache_key] = 1,(x[0][1]!=x[1][1])
            else:
                print('values: internal error %s, falling back to Z3' % t.op)
                return False
            return True
        if t.op in ('__add__','__mul__','SDiv','__floordiv__','SMod','__mod__','__sub__','__lshift__','LShR','__rshift__','__and__','__or__','__xor__'):
            bits = x[0][0]
            assert all(xi[0] == bits for xi in x)
            if t.op == '__and__': reduction = (lambda s,t:s&t)
            elif t.op == '__or__': reduction = (lambda s,t:s|t)
            elif t.op == '__xor__': reduction = (lambda s,t:s^t)
            elif t.op == '__add__': reduction = (lambda s,t:(s+t)%(2**bits))
            elif t.op == '__sub__': reduction = (lambda s,t:(s-t)%(2**bits))
            elif t.op == '__lshift__':
                # The count is unsigned and counts >= bits give 0 (SMT-LIB
                # bvshl); checking first also avoids building s<<t for huge t.
                reduction = (lambda s,t:0 if t >= bits else (s<<t)%(2**bits))
            elif t.op == 'LShR': reduction = (lambda s,t:(s>>t)%(2**bits))
            elif t.op == '__rshift__':
                def reduction(s,t):
                    # Arithmetic shift. The count is unsigned (SMT-LIB bvashr),
                    # so counts >= bits give all copies of the sign bit; a
                    # signed count would make Python's >> raise ValueError.
                    flip = 2**(bits-1)
                    ssigned = (s ^ flip) - flip
                    return (ssigned >> t)%(2**bits)
            elif t.op == '__mul__':
                reduction = (lambda s,t:(s*t)%(2**bits))
                W.add('mul')
            elif t.op == '__floordiv__':
                def reduction(s,t):
                    if t == 0: return 0
                    return (s//t)%(2**bits)
                W.add('div')
            elif t.op == '__mod__':
                def reduction(s,t):
                    if t == 0: return s
                    return (s%t)%(2**bits)
                W.add('div')
            elif t.op == 'SDiv':
                def reduction(s,t):
                    if t == 0: return 0
                    flip = 2**(bits-1)
                    ssigned = (s ^ flip) - flip
                    tsigned = (t ^ flip) - flip
                    # sdiv definition in Z3:
                    # - The \c floor of [t1/t2] if \c t2 is different from zero, and [t1*t2 >= 0].
                    # - The \c ceiling of [t1/t2] if \c t2 is different from zero, and [t1*t2 < 0].
                    if ssigned*tsigned >= 0:
                        usigned = ssigned // tsigned
                    else:
                        usigned = -((-ssigned) // tsigned)
                    return usigned%(2**bits)
                W.add('div')
            elif t.op == 'SMod':
                def reduction(s,t):
                    if t == 0: return s
                    flip = 2**(bits-1)
                    ssigned = (s ^ flip) - flip
                    tsigned = (t ^ flip) - flip
                    # srem definition in Z3:
                    # It is defined as t1 - (t1 /s t2) * t2, where /s represents signed division.
                    # The most significant bit (sign) of the result is equal to the most significant bit of \c t1.
                    if ssigned*tsigned >= 0:
                        usigned = ssigned - tsigned*(ssigned // tsigned)
                    else:
                        usigned = ssigned + tsigned*((-ssigned) // tsigned)
                    return usigned%(2**bits)
                W.add('div')
            else:
                print('values: internal error %s, falling back to Z3' % t.op)
                return False
            V[t.cache_key] = bits,functools.reduce(reduction,(xi[1] for xi in x))
            return True
        if t.op == '__invert__':
            assert len(x) == 1
            bits = x[0][0]
            V[t.cache_key] = bits,(1<<bits)-1-x[0][1]
            return True
        if t.op == 'Not':
            assert len(x) == 1
            assert all(xi[0] == 1 for xi in x)
            V[t.cache_key] = 1,1-x[0][1]
            return True
        if t.op in ('And','Or'):
            assert all(xi[0] == 1 for xi in x)
            if t.op == 'And': reduction = (lambda s,t:s*t)
            elif t.op == 'Or': reduction = (lambda s,t:s+t-s*t)
            else:
                print('values: internal error %s, falling back to Z3' % t.op)
                return False
            V[t.cache_key] = 1,functools.reduce(reduction,(xi[1] for xi in x))
            return True
        if t.op == 'If':
            assert len(x) == 3
            assert x[0][0] == 1
            if x[0][1]:
                V[t.cache_key] = x[1]
            else:
                V[t.cache_key] = x[2]
            return True
        if t.op in ('__le__','ULE','__lt__','ULT','__ge__','UGE','__gt__','UGT','SLE','SLT','SGE','SGT'):
            # the __xx__ comparisons are treated as unsigned, like the U* ops;
            # flipping the top bit maps signed order onto unsigned order
            assert len(x) == 2
            bits = x[0][0]
            assert bits == x[1][0]
            flip = 2**(bits-1)
            x0,x1 = x[0][1],x[1][1]
            if t.op == '__le__': V[t.cache_key] = (1,x0<=x1)
            elif t.op == 'ULE': V[t.cache_key] = (1,x0<=x1)
            elif t.op == '__lt__': V[t.cache_key] = (1,x0<x1)
            elif t.op == 'ULT': V[t.cache_key] = (1,x0<x1)
            elif t.op == '__ge__': V[t.cache_key] = (1,x0>=x1)
            elif t.op == 'UGE': V[t.cache_key] = (1,x0>=x1)
            elif t.op == '__gt__': V[t.cache_key] = (1,x0>x1)
            elif t.op == 'UGT': V[t.cache_key] = (1,x0>x1)
            elif t.op == 'SLE': V[t.cache_key] = (1,(x0^flip)<=(x1^flip))
            elif t.op == 'SLT': V[t.cache_key] = (1,(x0^flip)<(x1^flip))
            elif t.op == 'SGE': V[t.cache_key] = (1,(x0^flip)>=(x1^flip))
            elif t.op == 'SGT': V[t.cache_key] = (1,(x0^flip)>(x1^flip))
            else:
                print('values: internal error %s, falling back to Z3' % t.op)
                return False
            return True
        # XXX: add support for more
        print('values: unsupported operation %s, falling back to Z3' % t.op)
        return False
    # XXX: also add more validation for all of the above
    for t in terms:
        if not evaluate(t): return None,W
    return V,W
def unroll_print(outputs,unrolled,f):
    """Write the unrolled output expressions to f as straight-line code.

    Every distinct subterm (by cache_key) becomes one numbered line
    'vN = ...', children before parents; then one line
    'out_<symbol>_<index> = vN' follows per output entry, in API order.
    """
    number = {}
    def emit(rhs):
        # next free number; any children have already been numbered
        n = len(number)+1
        f.write('v%d = %s\n' % (n,rhs))
        return n
    def walk(t):
        key = t.cache_key
        if key in number:
            return number[key]
        op = t.op
        if op == 'BoolV':
            n = emit('bool(%d)' % t.args[0])
        elif op == 'BVV':
            n = emit('constant(%d,%d)' % (t.size(),t.args[0]))
        elif op == 'BVS':
            n = emit('%s' % (t.args[0],))
        elif op == 'Extract':
            assert len(t.args) == 3
            child = walk(t.args[2])
            n = emit('Extract(v%d,%d,%d)' % (child,t.args[0],t.args[1]))
        elif op in ('SignExt','ZeroExt'):
            assert len(t.args) == 2
            child = walk(t.args[1])
            n = emit('%s(v%d,%d)' % (op,child,t.args[0]))
        else:
            children = ','.join('v%d' % walk(a) for a in t.args)
            n = emit('%s(%s)' % (op,children))
        number[key] = n
        return n
    for term in unrolled:
        walk(term)
    position = 0
    for csymbol,symboltype,entries in outputs:
        for k in range(entries):
            f.write('out_%s_%d = v%s\n' % (csymbol,k,walk(unrolled[position])))
            position += 1
def unroll_inputvars(inputs):
    """Return [(name, BVS)] for every input entry, in API order.

    Names are 'in_<symbol>_<index>', created with explicit_name=True so they
    coincide with the symbols unroll_worker puts into memory.
    """
    result = []
    for csymbol,symboltype,entries in inputs:
        width = typebits[symboltype]
        names = ['in_%s_%d' % (csymbol,k) for k in range(entries)]
        result.extend((name,claripy.BVS(name,width,explicit_name=True)) for name in names)
    return result
# XXX: probably better to merge into unroll()
def unroll_worker(binary,inputs,outputs,rust):
    """Symbolically execute an analysis-angr binary with angr.

    Every input entry of the static_<symbol> arrays is overwritten with a
    fresh symbol in_<symbol>_<index>; the program is stepped until no active
    states remain, and each exit's final static_<symbol> output arrays are read.

    Returns (numexits, ispartition, results):
      numexits: number of deadended states, or -1 on failure;
      ispartition: whether the exits' path constraints cover every input and
        are pairwise disjoint (False on failure);
      results: one expression per output entry in API order, merged across
        exits as If(constraints_of_exit, value, previous); on failure, the
        errored states or a message string instead.
    """
    results = []
    if rust:
        proj = angr.Project(binary,auto_load_libs=False)
    else:
        avoidsimprocedures = (
            'memcmp', # want to test the real libc memcmp
        )
        proj = angr.Project(binary,exclude_sim_procedures_list=avoidsimprocedures,auto_load_libs=True)
    state = proj.factory.full_init_state()
    state.options |= {angr.options.LAZY_SOLVES}
    state.options |= {angr.options.SYMBOL_FILL_UNCONSTRAINED_MEMORY}
    state.options |= {angr.options.SYMBOL_FILL_UNCONSTRAINED_REGISTERS}
    state.options -= {angr.options.SIMPLIFY_EXPRS}
    state.options -= {angr.options.SIMPLIFY_REGISTER_WRITES}
    state.options -= {angr.options.SIMPLIFY_MEMORY_WRITES}
    state.options -= {angr.options.SIMPLIFY_REGISTER_READS}
    state.options -= {angr.options.SIMPLIFY_MEMORY_READS}
    # make each input entry symbolic, at its offset in static_<symbol>
    for csymbol,symboltype,entries in inputs:
        xaddr = proj.loader.find_symbol('static_%s'%csymbol).rebased_addr
        for i in range(entries):
            varname = 'in_%s_%d'%(csymbol,i)
            variable = claripy.BVS(varname,typebits[symboltype],explicit_name=True)
            if typebits[symboltype] == 8:
                state.mem[xaddr+i].char = variable
            elif typebits[symboltype] == 16:
                state.mem[xaddr+2*i].short = variable
            elif typebits[symboltype] == 32:
                state.mem[xaddr+4*i].int = variable
            elif typebits[symboltype] == 64:
                state.mem[xaddr+8*i].long = variable
    simgr = proj.factory.simgr(state)
    # step until every state has exited; give up on errors or too many paths
    while True:
        if len(simgr.errored) > 0:
            return -1,False,simgr.errored
        if len(simgr.deadended)+len(simgr.active) > maxsplit:
            return -1,False,'saferewrite limiting split to %d'%maxsplit
        if len(simgr.active) == 0:
            break
        simgr.step()
    exits = simgr.deadended
    assert len(exits) > 0
    # cannot be safe if there are multiple exits
    # for equivalence tests we'll merge exits below
    mergedconstraints = []
    for epos,e in enumerate(exits):
        # conjunction of all path constraints of this exit
        mergedconstraint = e.solver.true
        for c in e.solver.constraints:
            mergedconstraint = e.solver.And(mergedconstraint,c)
        mergedconstraints += [mergedconstraint]
        resultpos = 0
        for csymbol,symboltype,entries in outputs:
            xaddr = proj.loader.find_symbol('static_%s'%csymbol).rebased_addr
            for i in range(entries):
                if typebits[symboltype] == 8:
                    xi = e.mem[xaddr+i].char.resolved
                elif typebits[symboltype] == 16:
                    xi = e.mem[xaddr+2*i].short.resolved
                elif typebits[symboltype] == 32:
                    xi = e.mem[xaddr+4*i].int.resolved
                elif typebits[symboltype] == 64:
                    xi = e.mem[xaddr+8*i].long.resolved
                if epos == 0:
                    # first exit supplies the base value
                    assert len(results) == resultpos
                    results += [xi]
                else:
                    # later exits take precedence where their constraints hold
                    results[resultpos] = e.solver.If(mergedconstraint,xi,results[resultpos])
                resultpos += 1
        assert resultpos == len(results)
    assert len(mergedconstraints) == len(exits)
    ispartition = True
    # are mergedconstraints a partition of all universes?
    # i.e.: in each universe, exactly one of the constraints is satisfied?
    # coverage: no input may violate every exit's constraints
    s = claripy.Solver(timeout=z3timeout)
    for c in mergedconstraints:
        s.add(claripy.Not(c))
    if s.satisfiable():
        ispartition = False
    # disjointness: no input may satisfy two exits' constraints
    for i in range(len(exits)):
        for j in range(i):
            s = claripy.Solver(timeout=z3timeout)
            s.add(mergedconstraints[i])
            s.add(mergedconstraints[j])
            if s.satisfiable():
                ispartition = False
    return len(exits),ispartition,results
def unroll(o,p,i,compiler):
    """Symbolically unroll one build and cross-check it against CPU execution.

    Runs unroll_worker on analysis-angr. It fails if that errors or if the
    exits do not partition the input space, and notes unsafe-unrollsplit-N when
    there are N>1 exits. It writes analysis/unrolled and warns when the
    expressions use variables other than the inputs, or leave inputs unused.
    Then, for each random test, it checks that the unrolled expressions
    reproduce the CPU outputs in opic_y. It uses values() where possible and
    Z3 otherwise.

    Returns (o,p,i,compiler,unrolled) on success, else (o,p,i,compiler,False).
    """
    compilerword = compiler.replace(' ','_').replace('=','_')
    implementationdir = '%s/%s/%s' % (o,p,i)
    builddir = 'build/%s/%s/%s' % (p,i,compilerword)
    inputs,outputs,funargs,funargtypes,crustargs,crustargtypes,rustargs,funname,funret,funrettype = op_api[o,p]
    rust = 'Cargo.toml' in os.listdir(implementationdir)
    unrolltime = -cputime()
    numexits,ispartition,unrolled = unroll_worker('%s/analysis-angr'%builddir,inputs,outputs,rust)
    if numexits < 1:
        note(builddir,'warning-unrollerror',unrolled)
        return o,p,i,compiler,False
    if not ispartition:
        note(builddir,'warning-unrollnotpartition')
        return o,p,i,compiler,False
    if numexits > 1:
        note(builddir,'unsafe-unrollsplit-%d'%numexits)
    with open('%s/analysis/unrolled' % builddir,'w') as f:
        unroll_print(outputs,unrolled,f)
    okvars = set(vname for vname,v in unroll_inputvars(inputs))
    usedvars = set(v for x in unrolled for v in x.variables)
    if not usedvars.issubset(okvars):
        note(builddir,'warning-unrollmem')
    if not okvars.issubset(usedvars):
        note(builddir,'warning-unusedinputs')
    Wdone = set()
    for x,y in zip(op_x[o,p],opic_y[o,p,i,compiler]):
        # cpu gave us outputs y given inputs x
        # does this match unrolled?
        replacements = {}
        xpos = 0
        for csymbol,symboltype,entries in inputs:
            for e in range(entries):
                varname = 'in_%s_%d'%(csymbol,e)
                variable = claripy.BVS(varname,typebits[symboltype],explicit_name=True)
                replacements[variable.cache_key] = claripy.BVV(x[xpos],typebits[symboltype])
                xpos += 1
        assert xpos == len(x)
        mismatch = True
        notestr = 'internal error\n'
        V = None
        try:
            V,W = values(unrolled,replacements)
            for warn in W:
                if warn in Wdone: continue
                Wdone.add(warn)
                note(builddir,'warning-%s'%warn)
        except AssertionError:
            note(builddir,'warning-valuesfailed',traceback.format_exc())
            # proceed with z3 fallback below
        if V is not None:
            mismatch = any(yi != V[unrolledi.cache_key][1] for (yi,unrolledi) in zip(y,unrolled))
            if mismatch:
                notestr = '# mismatch between CPU execution of binary and saferewrite execution of unrolled:\n'
                for (vname,v),xi in zip(unroll_inputvars(inputs),x):
                    notestr += '%s = %d\n' % (vname,xi)
                pos = 0
                for csymbol,symboltype,entries in outputs:
                    for e in range(entries):
                        varname = 'out_%s_%d'%(csymbol,e)
                        notestr += 'executed_%s = %s\n' % (varname,y[pos])
                        notestr += 'unrolled_%s = %s\n' % (varname,V[unrolled[pos].cache_key][1])
                        pos += 1
        else:
            # fall back on Z3 for figuring this out; after substituting the
            # inputs, any remaining variables (e.g. unconstrained memory) are
            # free, so "satisfiable" means some choice of them mismatches
            replaced = [u.replace_dict(replacements) for u in unrolled]
            s = claripy.Solver(timeout=z3timeout)
            mismatch = claripy.false
            for yi,replacedi in zip(y,replaced):
                mismatch = claripy.Or(mismatch,replacedi != yi)
            s.add(mismatch)
            mismatch = s.satisfiable()
            if mismatch:
                # keep the hint (previously overwritten by the next line)
                notestr = '# hint: this type of mismatch typically reflects reading undefined memory.\n'
                notestr += '# mismatch between CPU execution of binary and Z3 execution of unrolled:\n'
                # report the tested inputs: they were substituted out of the
                # solver's formula, so evaluating the input symbols there would
                # give arbitrary values
                for (vname,v),xi in zip(unroll_inputvars(inputs),x):
                    notestr += '%s = %d\n' % (vname,xi)
                pos = 0
                for csymbol,symboltype,entries in outputs:
                    for e in range(entries):
                        varname = 'out_%s_%d'%(csymbol,e)
                        notestr += 'executed_%s = %s\n' % (varname,y[pos])
                        notestr += 'unrolled_%s = %s\n' % (varname,s.eval(replaced[pos],1)[0])
                        pos += 1
        if mismatch:
            note(builddir,'warning-unrollmismatch',notestr)
            return o,p,i,compiler,False
    unrolltime += cputime()
    notetime(builddir,'unroll',unrolltime)
    return o,p,i,compiler,unrolled
def wanttounroll():
    """Yield (o, p, i, compiler) argument tuples for unroll().

    Only compiled implementations that have random-test outputs recorded
    in opic_y are produced, in primitives / op_compiled order.
    """
    for o, p in primitives:
        yield from (
            (o, p, i, compiler)
            for i, compiler in op_compiled[o, p]
            if (o, p, i, compiler) in opic_y
        )
# Unroll every eligible implementation in parallel. unroll() signals failure
# by returning False in the last slot; only successful results are kept.
opic_unrolled = {}
with multiprocessing.Pool(os_cores) as pool:
    unrollresults = pool.starmap(unroll, wanttounroll())
for o, p, i, compiler, unrolled in unrollresults:
    if unrolled is not False:
        opic_unrolled[o, p, i, compiler] = unrolled
print('----- compareunrolled (can take tons of time)')
def compareunrolled(o,p,i,compiler,source,sourcecompiler):
    """Compare implementation i (built by compiler) against source (built by sourcecompiler).

    Phase 1: compare the recorded random-test outputs (opic_y) of the two
    builds. Each differing test is noted as unsafe-randomtest-*. If any test
    differs and satvalidation1 is False, the function stops there.

    Phase 2: ask Z3 (through claripy) whether some input makes any unrolled
    output of the two builds differ. The result is noted either as
    unsafe-differentfrom-* with a counterexample or as equals-*. The CPU time
    of this phase is recorded with notetime(..., 'equiv').

    Returns None; all results are reported through note()/notetime().
    """
    compilerword = compiler.replace(' ','_').replace('=','_')
    sourcecompilerword = sourcecompiler.replace(' ','_').replace('=','_')
    builddir = 'build/%s/%s/%s' % (p,i,compilerword)
    inputs,outputs,funargs,funargtypes,crustargs,crustargtypes,rustargs,funname,funret,funrettype = op_api[o,p]

    randomtestspass = True
    for pos,(x,y,z) in enumerate(zip(op_x[o,p],opic_y[o,p,i,compiler],opic_y[o,p,source,sourcecompiler])):
        if y != z:
            xstr = input_example_str(inputs,x)
            note(builddir,'unsafe-randomtest-%d-differentfrom-%s-%s' % (pos,source,sourcecompilerword),xstr)
            randomtestspass = False
            # keep going through the remaining random tests to get statistics
    if not randomtestspass and not satvalidation1:
        return

    equivtime = -cputime()
    u1 = opic_unrolled[o,p,source,sourcecompiler]
    u2 = opic_unrolled[o,p,i,compiler]
    assert len(u1) == len(u2)

    # XXX: allow other equivalence-testing techniques
    # z3timeout is effectively unlimited: claripy's satisfiable() otherwise
    # reports a timeout as False, which would be misread as "equal".
    s = claripy.Solver(timeout=z3timeout)
    different = claripy.false
    for u1j,u2j in zip(u1,u2):
        different = claripy.Or(different,u1j != u2j)
    s.add(different)
    try:
        mismatch = s.satisfiable()
    except claripy.errors.ClaripyZ3Error:
        # avoid crashing on the sort of bug fixed in https://github.com/angr/angr/pull/2887
        note(builddir,'warning-z3failed',traceback.format_exc())
        return

    if mismatch:
        # angr documentation says:
        # "If you don't add any constraints between two queries, the results will be consistent with each other."
        # so the eval() calls below all describe one counterexample.
        lines = ['%s = %s\n' % (vname,s.eval(v,1)[0]) for vname,v in unroll_inputvars(inputs)]
        unrolledpos = 0
        for csymbol,symboltype,entries in outputs:
            # use e (not i) so the implementation-name parameter is not clobbered
            for e in range(entries):
                varname = 'out_%s_%d'%(csymbol,e)
                lines.append('source_%s = %s\n' % (varname,s.eval(u1[unrolledpos],1)[0]))
                lines.append('target_%s = %s\n' % (varname,s.eval(u2[unrolledpos],1)[0]))
                unrolledpos += 1
        note(builddir,'unsafe-differentfrom-%s-%s' % (source,sourcecompilerword),''.join(lines))
    else:
        note(builddir,'equals-%s-%s' % (source,sourcecompilerword))
    equivtime += cputime()
    notetime(builddir,'equiv',equivtime)
def wanttocompareunrolled():
    """Yield (o, p, i, compiler, source, sourcecompiler) argument tuples for compareunrolled.

    Every implementation is compared against 'ref'. The ref implementation
    is paired with ref built by compilerlist[0]. Every other implementation
    is paired with ref built by its own compiler. A pair is skipped when
    either side has no entry in opic_unrolled, or when both sides are the
    same build.
    """
    source = 'ref'  # XXX: allow each implementation to choose source
    for o, p in primitives:
        for i, compiler in op_compiled[o, p]:
            # XXX: maybe also allow choice of sourcecompiler
            sourcecompiler = compilerlist[0] if i == source else compiler
            target = (o, p, i, compiler)
            reference = (o, p, source, sourcecompiler)
            bothunrolled = target in opic_unrolled and reference in opic_unrolled
            # XXX: could also do self-tests
            if bothunrolled and target != reference:
                yield o, p, i, compiler, source, sourcecompiler
# Run every eligible comparison in parallel. compareunrolled() returns None
# and reports via note(), so the result list from starmap is not used.
with multiprocessing.Pool(os_cores) as pool:
    pool.starmap(compareunrolled, wanttocompareunrolled())