-rwxr-xr-x 31359 saferewrite-20210904/analyze
#!/usr/bin/env python3
"""Check implementations of each primitive under src/ against each other.

For every primitive directory src/<p> that contains an ``api`` file, each
implementation directory src/<p>/<i> is built with every compiler in
``compilerlist`` under build/<p>/<i>/<compiler>/ and then:

1. execute: run on ``numrandomtests`` inputs (all zeros, all ones, then
   random values) and record the outputs;
2. valgrind: run a driver whose input buffers are malloc'd but never
   written, recording unsafe-valgrindfailure when valgrind reports errors;
3. unroll: symbolically execute a driver with angr and check the
   resulting expressions against the outputs recorded in step 1;
4. compareunrolled: ask Z3 (through claripy) whether the unrolled
   expressions can differ from those of the ``ref`` implementation.

Conclusions are recorded as files in build/<p>/<i>/<compiler>/analysis/.
"""

compilerlist = (
    'clang -O1 -fwrapv -march=native',
    'gcc -O3 -march=native -mtune=native',
)

numrandomtests = 16

avoidsimprocedures = (
    'memcmp',  # we want to test the real libc memcmp
)

typebits = {
    'int8': 8,
    'int16': 16,
    'int32': 32,
    'int64': 64,
}

import sys
import os
import shutil
import subprocess
import angr
import claripy
import multiprocessing
import random
import traceback
import functools

try:
    os_cores = len(os.sched_getaffinity(0))
except AttributeError:
    os_cores = multiprocessing.cpu_count()

# the CORES environment variable overrides the detected core count
os_cores = os.getenv('CORES', default=os_cores)
os_cores = int(os_cores)
if os_cores < 1:
    os_cores = 1

import resource


def cputime():
    """Return user CPU seconds used by this process plus its reaped children."""
    return (resource.getrusage(resource.RUSAGE_SELF).ru_utime
            + resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime)


def notetime(builddir, what, time):
    """Print a timing result and append it to builddir/analysis/seconds."""
    print('%s seconds %s %.6f' % (builddir, what, time))
    sys.stdout.flush()
    with open('%s/analysis/seconds' % builddir, 'a') as f:
        f.write('%s %.6f\n' % (what, time))


def note(builddir, conclusion, contents=None):
    """Print a conclusion and record it as the file builddir/analysis/<conclusion>.

    The file is always created; if contents is not None, str(contents) is
    written into it.
    """
    print('%s %s' % (builddir, conclusion))
    sys.stdout.flush()
    with open('%s/analysis/%s' % (builddir, conclusion), 'w') as f:
        if contents is not None:
            f.write(str(contents))


def _compilerword(compiler):
    """Return compiler with ' ' and '=' replaced by '_', for use in a path."""
    return compiler.replace(' ', '_').replace('=', '_')


def _builddir(p, i, compiler):
    """Return the build directory of implementation i of primitive p under compiler."""
    return 'build/%s/%s/%s' % (p, i, _compilerword(compiler))


def _pool():
    """Return a pool of os_cores worker processes started with fork.

    The workers read module-level state computed while this script runs
    (op_api, op_x, opic_y, opic_unrolled, ...).  Forked workers inherit it;
    with the spawn/forkserver start methods each worker would instead
    re-import this script and re-run its top-level code.
    """
    return multiprocessing.get_context('fork').Pool(os_cores)


sys.setrecursionlimit(1000000)

startdir = os.getcwd()
# presumably restricted because startdir is pasted into a whitespace-split
# linker command (-Wl,-rpath=...) below
assert all(x in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_./' for x in startdir)

shutil.rmtree('build', ignore_errors=True)
os.makedirs('build')

# ----- find primitives: src/<p> directories with an api file, not sticky

primitives = []

for o in 'src',:
    o = o.strip()
    if o == '':
        continue
    if not os.path.isdir(o):
        continue
    if (os.stat('%s' % o).st_mode & 0o1000) == 0o1000:
        print('%s sticky, skipping' % o)
        sys.stdout.flush()
        continue
    for p in sorted(os.listdir(o)):
        if not os.path.isdir('%s/%s' % (o, p)):
            continue
        if (os.stat('%s/%s' % (o, p)).st_mode & 0o1000) == 0o1000:
            print('%s/%s sticky, skipping' % (o, p))
            sys.stdout.flush()
            continue
        if not os.path.exists('%s/%s/api' % (o, p)):
            print('%s/%s/api nonexistent, skipping' % (o, p))
            sys.stdout.flush()
            continue
        primitives += [(o, p)]

# ----- parse each api file
# lines: "call NAME", "return TYPE SYM", "in|out|inout TYPE SYM [ENTRIES]"
# (ENTRIES present means the argument is passed as a pointer)

op_api = {}

for o, p in primitives:
    inputs = []
    outputs = []
    funargs = []
    funargtypes = []
    funname = None
    funret = None
    funrettype = 'void'
    with open('%s/%s/api' % (o, p)) as f:
        for line in f:
            line = line.split()
            if len(line) == 0:
                continue
            if line[0] == 'call':
                funname = line[1]
            if line[0] == 'return':
                bitsperentry = typebits[line[1]]
                csymbol = line[2]
                assert all(c in 'abcdefghijklmnopqrstuvwxyz' for c in csymbol)
                entries = 1
                outputs += [(csymbol, bitsperentry, entries)]
                funret = 'alloc_%s' % csymbol
                funrettype = 'uint%d_t' % bitsperentry
            if line[0] in ('in', 'out', 'inout'):
                bitsperentry = typebits[line[1]]
                csymbol = line[2]
                assert all(c in 'abcdefghijklmnopqrstuvwxyz' for c in csymbol)
                if len(line) == 3:
                    pointer = False
                    entries = 1
                else:
                    pointer = True
                    entries = int(line[3])
                if line[0] in ('in', 'inout'):
                    inputs += [(csymbol, bitsperentry, entries)]
                if line[0] in ('out', 'inout'):
                    outputs += [(csymbol, bitsperentry, entries)]
                if pointer:
                    funargs += ['alloc_%s' % csymbol]
                    funargtypes += ['uint%d_t *' % bitsperentry]
                else:
                    funargs += ['*alloc_%s' % csymbol]
                    funargtypes += ['uint%d_t' % bitsperentry]
    # XXX: support constant inputs
    op_api[o, p] = inputs, outputs, funargs, funargtypes, funname, funret, funrettype


def input_example_str(inputs, x):
    """Format the flat input list x as 'in_<symbol>_<e> = <value>' lines."""
    xstr = ''
    xpos = 0
    for csymbol, bitsperentry, entries in inputs:
        for e in range(entries):
            varname = 'in_%s_%d' % (csymbol, e)
            xstr += '%s = %d\n' % (varname, x[xpos])
            xpos += 1
    assert xpos == len(x)
    return xstr


def output_example_str(outputs, y):
    """Format the flat output list y as 'out_<symbol>_<e> = <value>' lines."""
    ystr = ''
    ypos = 0
    for csymbol, bitsperentry, entries in outputs:
        for e in range(entries):
            varname = 'out_%s_%d' % (csymbol, e)
            ystr += '%s = %d\n' % (varname, y[ypos])
            ypos += 1
    assert ypos == len(y)
    return ystr


# names that the build step itself creates inside each build directory
reservedfilenames = (
    'library.so.1',
    'analysis',
    'analysis-execute',
    'analysis-execute.c',
    'analysis-valgrind',
    'analysis-valgrind.c',
    'analysis-angr',
    'analysis-angr.c',
)

# ----- find implementations: non-sticky src/<p>/<i> dirs whose files are all acceptable

opimplementations = {}

for o, p in primitives:
    opimplementations[o, p] = []
    for i in sorted(os.listdir('%s/%s' % (o, p))):
        implementationdir = '%s/%s/%s' % (o, p, i)
        if not os.path.isdir(implementationdir):
            continue
        if (os.stat(implementationdir).st_mode & 0o1000) == 0o1000:
            print('%s/%s/%s sticky, skipping' % (o, p, i))
            continue
        files = sorted(os.listdir(implementationdir))
        # one bad file disqualifies the whole implementation, so ok must be
        # initialized once per implementation, not once per file
        ok = True
        for f in files:
            if f in reservedfilenames:
                print('%s/%s/%s/%s reserved filename' % (o, p, i, f))
                ok = False
            if any(fi not in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.' for fi in f):
                print('%s/%s/%s/%s prohibited character' % (o, p, i, f))
                ok = False
        if not ok:
            continue
        opimplementations[o, p] += [i]
        os.makedirs('build/%s/%s' % (p, i), exist_ok=True)


def _runbuild(builddir, command, what):
    """Run a compile or link command (a list) in builddir; return True on success.

    Any output is recorded as warning-<what>output; an OSError or a nonzero
    exit status is recorded as warning-<what>failed.
    """
    try:
        proc = subprocess.Popen(command, cwd=builddir, stdout=subprocess.PIPE,
                                stderr=subprocess.STDOUT, universal_newlines=True)
        out, err = proc.communicate()
    except OSError:
        note(builddir, 'warning-%sfailed' % what, traceback.format_exc())
        return False
    assert not err  # stderr is merged into stdout
    if out != '':
        note(builddir, 'warning-%soutput' % what, out)
    if proc.returncode:
        note(builddir, 'warning-%sfailed' % what, 'exit code %s' % proc.returncode)
        return False
    return True


def _write_driver(f, analysis, api):
    """Write the C driver for one analysis ('execute', 'valgrind' or 'angr') to f.

    Every driver declares static_<symbol> arrays for all inputs/outputs,
    mallocs the argument buffers and calls the function under test.
    'execute' reads inputs from stdin and prints each output entry as a
    decimal line; 'execute' and 'angr' copy inputs from static_* into the
    buffers and outputs back; 'valgrind' passes the malloc'd buffers
    without ever writing them.
    """
    inputs, outputs, funargs, funargtypes, funname, funret, funrettype = api

    f.write('#include <stdio.h>\n')
    f.write('#include <stdlib.h>\n')
    f.write('#include <string.h>\n')
    f.write('#include <inttypes.h>\n')
    f.write('\n')

    # function declaration
    f.write('extern ')
    if funrettype is not None:
        f.write('%s ' % funrettype)
    f.write('%s(%s);\n' % (funname, ','.join(funargtypes)))
    f.write('\n')

    for csymbol, bitsperentry, entries in inputs + outputs:
        f.write('uint%d_t static_%s[%d];\n' % (bitsperentry, csymbol, entries))
    f.write('\n')

    f.write('int main(int argc,char **argv)\n')
    f.write('{\n')

    for csymbol, bitsperentry, entries in inputs:
        f.write(' uint%d_t *alloc_%s = malloc(%d);\n' % (bitsperentry, csymbol, entries * bitsperentry // 8))
    for csymbol, bitsperentry, entries in outputs:
        if (csymbol, bitsperentry, entries) not in inputs:  # inout buffers already allocated
            f.write(' uint%d_t *alloc_%s = malloc(%d);\n' % (bitsperentry, csymbol, entries * bitsperentry // 8))
    f.write('\n')

    # XXX: resource limits

    if analysis == 'execute':
        for csymbol, bitsperentry, entries in inputs:
            f.write(' for (long long i = 0;i < %d;++i) {\n' % entries)
            f.write(' unsigned long long x;\n')
            f.write(' if (scanf("%llu",&x) != 1) abort();\n')
            f.write(' static_%s[i] = x;\n' % csymbol)
            f.write(' }\n')
        f.write('\n')

    if analysis in ('execute', 'angr'):
        for csymbol, bitsperentry, entries in inputs:
            f.write(' for (long long i = 0;i < %d;++i)\n' % entries)
            f.write(' alloc_%s[i] = static_%s[i];\n' % (csymbol, csymbol))
        f.write('\n')

    f.write(' ')
    if funret is not None:
        f.write('%s[0] = ' % funret)
    f.write('%s(%s);\n' % (funname, ','.join(funargs)))
    f.write('\n')

    if analysis in ('execute', 'angr'):
        for csymbol, bitsperentry, entries in outputs:
            f.write(' for (long long i = 0;i < %d;++i)\n' % entries)
            f.write(' static_%s[i] = alloc_%s[i];\n' % (csymbol, csymbol))
        f.write('\n')

    if analysis == 'execute':
        for csymbol, bitsperentry, entries in outputs:
            f.write(' for (long long i = 0;i < %d;++i) {\n' % entries)
            f.write(' unsigned long long x = static_%s[i];\n' % csymbol)
            f.write(' printf("%llu\\n",x);\n')
            f.write(' }\n')
        f.write(' fflush(stdout);\n')
        f.write('\n')

    f.write(' return 0;\n')
    f.write('}\n')


def compile(o, p, i, compiler):
    """Build implementation i of primitive (o, p) with one compiler.

    Copies the implementation into its build directory, writes
    crypto_int*/crypto_uint* headers and the three analysis drivers,
    compiles every .c/.s/.S file plus the drivers, and links each driver
    with the implementation's object files.

    Returns (o, p, i, compiler, ok); ok is False if compiling or linking
    failed (a warning-* note is recorded in that case).
    """
    implementationdir = '%s/%s/%s' % (o, p, i)
    builddir = _builddir(p, i, compiler)

    files = sorted(os.listdir(implementationdir))
    cfiles = [x for x in files if x.endswith('.c')]
    sfiles = [x for x in files if x.endswith('.s') or x.endswith('.S')]
    files = cfiles + sfiles

    shutil.copytree(implementationdir, builddir)
    os.makedirs('%s/analysis' % builddir)

    for bits in 8, 16, 32, 64:
        with open('%s/crypto_int%d.h' % (builddir, bits), 'w') as f:
            f.write('#include <inttypes.h>\n')
            f.write('#define crypto_int%d int%d_t\n' % (bits, bits))
        with open('%s/crypto_uint%d.h' % (builddir, bits), 'w') as f:
            f.write('#include <inttypes.h>\n')
            f.write('#define crypto_uint%d uint%d_t\n' % (bits, bits))

    for analysis in 'execute', 'valgrind', 'angr':
        with open('%s/analysis-%s.c' % (builddir, analysis), 'w') as f:
            _write_driver(f, analysis, op_api[o, p])

    # ----- compile

    compiletime = -cputime()

    objfiles = []
    for f in files + ['analysis-execute.c', 'analysis-valgrind.c', 'analysis-angr.c']:
        command = '%s -Wall -fPIC -DCRYPTO_NAMESPACE(x)=x -c %s' % (compiler, f)
        if not _runbuild(builddir, command.split(), 'compile'):
            return o, p, i, compiler, False
        if f in files:
            objfiles += ['.'.join(f.split('.')[:-1] + ['o'])]

    compiletime += cputime()
    notetime(builddir, 'compile', compiletime)

    # ----- link into executable

    linktime = -cputime()

    # static=False links each driver against a shared library.so.1 instead
    static = True
    for analysis in 'execute', 'valgrind', 'angr':
        if static:
            command = ('gcc -no-pie -o analysis-%s analysis-%s.o' % (analysis, analysis)).split()
            command += objfiles
        else:
            command = 'gcc -shared -Wl,-soname,library.so.1 -o library.so.1'.split()
            command += objfiles
            if not _runbuild(builddir, command, 'link'):
                return o, p, i, compiler, False
            shutil.copy('%s/library.so.1' % builddir, '%s/library.so' % builddir)
            # -lrary resolves to library.so
            command = ('gcc -no-pie -o analysis-%s analysis-%s.o -Wl,-rpath=%s/%s -L. -lrary'
                       % (analysis, analysis, startdir, builddir)).split()
        if not _runbuild(builddir, command, 'link'):
            return o, p, i, compiler, False

    linktime += cputime()
    notetime(builddir, 'link', linktime)

    return o, p, i, compiler, True


def wanttocompile():
    """Yield (o, p, i, compiler) for every implementation and compiler."""
    for o, p in primitives:
        for i in opimplementations[o, p]:
            for compiler in compilerlist:
                yield o, p, i, compiler


op_compiled = {}
for o, p in primitives:
    op_compiled[o, p] = []

with _pool() as pool:
    for o, p, i, compiler, ok in pool.starmap(compile, wanttocompile()):
        if not ok:
            continue
        op_compiled[o, p] += [(i, compiler)]

print('----- execute')

# test inputs per primitive: all zeros, all ones, then uniformly random
op_x = {}

for o, p in primitives:
    inputs = op_api[o, p][0]
    op_x[o, p] = []
    for execution in range(numrandomtests):
        x = []
        for csymbol, bitsperentry, entries in inputs:
            for e in range(entries):
                if execution == 0:
                    r = 0
                elif execution == 1:
                    r = 2**bitsperentry - 1
                else:
                    r = random.randrange(2**bitsperentry)
                x += [r]
        op_x[o, p] += [x]


def _parse_outputs(outputs, ystr):
    """Parse analysis-execute output into a flat list of ints, or None if malformed.

    Well-formed output is exactly one decimal line per output entry, each
    value in [0, 2**bitsperentry).
    """
    try:
        y = [int(s) for s in ystr.splitlines()]
    except ValueError:
        return None
    bounds = [2**bitsperentry
              for csymbol, bitsperentry, entries in outputs
              for e in range(entries)]
    if len(y) != len(bounds):
        return None
    if not all(0 <= yi < bound for yi, bound in zip(y, bounds)):
        return None
    return y


def execute(o, p, i, compiler):
    """Run analysis-execute on every test input of primitive (o, p).

    Returns (o, p, i, compiler, results) with one output list per test
    input, or False in place of results (after recording a warning-execute*
    note) if a run fails or prints malformed output.
    """
    builddir = _builddir(p, i, compiler)
    inputs, outputs = op_api[o, p][:2]

    executetime = -cputime()

    results = []
    command = ['./analysis-execute']
    for x in op_x[o, p]:
        xstr = ''.join('%d\n' % r for r in x)
        try:
            proc = subprocess.Popen(command, cwd=builddir, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
                                    stderr=subprocess.STDOUT, universal_newlines=True)
            ystr, err = proc.communicate(input=xstr)
        except OSError:
            note(builddir, 'warning-executeerror', xstr)
            return o, p, i, compiler, False
        if proc.returncode != 0:
            note(builddir, 'warning-executefailed', xstr + 'exit code %s' % proc.returncode)
            return o, p, i, compiler, False
        y = _parse_outputs(outputs, ystr)
        if y is None:
            # record the raw output: it could not be parsed into outputs
            note(builddir, 'warning-executebadformat', input_example_str(inputs, x) + ystr)
            return o, p, i, compiler, False
        results += [y]

    executetime += cputime()
    notetime(builddir, 'execute', executetime)

    return o, p, i, compiler, results


def wanttoexecute():
    """Yield (o, p, i, compiler) for every successful build."""
    for o, p in primitives:
        for i, compiler in op_compiled[o, p]:
            yield o, p, i, compiler


opic_y = {}

with _pool() as pool:
    for o, p, i, compiler, results in pool.starmap(execute, wanttoexecute()):
        if results is False:
            continue
        opic_y[o, p, i, compiler] = results

print('----- valgrind (can take some time)')


def valgrind(o, p, i, compiler):
    """Run analysis-valgrind under valgrind and record the verdict.

    Notes unsafe-valgrindfailure if valgrind exits with its error code (99)
    or its output mentions 'client request', and warning-valgrinderror if
    valgrind cannot be started or exits with another nonzero status.
    """
    builddir = _builddir(p, i, compiler)

    valgrindtime = -cputime()

    command = ['valgrind', '-q', '--error-exitcode=99', './analysis-valgrind']
    valgrindstatus = None
    try:
        proc = subprocess.Popen(command, cwd=builddir, stdout=subprocess.PIPE,
                                stderr=subprocess.STDOUT, universal_newlines=True)
        out, err = proc.communicate()
    except OSError:
        valgrindstatus = 'warning-valgrinderror'

    if valgrindstatus is None:
        assert not err
        if proc.returncode == 99:
            valgrindstatus = 'unsafe-valgrindfailure'
        elif proc.returncode != 0:
            valgrindstatus = 'warning-valgrinderror'
        elif 'client request' in out:
            valgrindstatus = 'unsafe-valgrindfailure'

    if valgrindstatus is not None:
        note(builddir, valgrindstatus)

    valgrindtime += cputime()
    notetime(builddir, 'valgrind', valgrindtime)


def wanttovalgrind():
    """Yield (o, p, i, compiler) for every successful build."""
    for o, p in primitives:
        for i, compiler in op_compiled[o, p]:
            yield o, p, i, compiler


with _pool() as pool:
    list(pool.starmap(valgrind, wanttovalgrind()))

print('----- unroll (can take tons of time)')
# XXX: could do this in parallel with valgrind
# XXX: unrolled can be huge; pass through disk instead of RAM


def values(terms, replacements):
    """Evaluate claripy expressions concretely without calling a solver.

    Input: replacements maps the cache_key of each input BVS to a BVV.
    Output: dictionary V mapping cache_key to pairs (b, i), where i is a
    b-bit value (Booleans use b = 1); V covers every term in terms and all
    of their subterms.  Returns None instead if some term uses a variable
    outside replacements or an operation not handled here.  May raise
    AssertionError on unexpected operands (e.g. an arithmetic right shift
    by >= the width); callers fall back to Z3 in that case.
    """
    comparisons = {
        '__le__': lambda a, b: a <= b,
        'ULE': lambda a, b: a <= b,
        '__lt__': lambda a, b: a < b,
        'ULT': lambda a, b: a < b,
        '__ge__': lambda a, b: a >= b,
        'UGE': lambda a, b: a >= b,
        '__gt__': lambda a, b: a > b,
        'UGT': lambda a, b: a > b,
        'SLE': lambda a, b: a <= b,
        'SLT': lambda a, b: a < b,
        'SGE': lambda a, b: a >= b,
        'SGT': lambda a, b: a > b,
    }
    signedcomparisons = ('SLE', 'SLT', 'SGE', 'SGT')

    V = {}

    def evaluate(t):
        if t.cache_key in V:
            return True
        if t.op == 'BoolV':
            V[t.cache_key] = 1, t.args[0]
            return True
        if t.op == 'BVV':
            V[t.cache_key] = t.size(), t.args[0]
            return True
        if t.op == 'BVS':
            if t.cache_key not in replacements:
                return False
            V[t.cache_key] = t.size(), replacements[t.cache_key].args[0]
            return True
        if t.op == 'Extract':
            assert len(t.args) == 3
            top, bot, child = t.args
            if not evaluate(child):
                return False
            x0 = V[child.cache_key]
            assert x0[0] > top
            assert top >= bot
            assert bot >= 0
            V[t.cache_key] = top + 1 - bot, ((x0[1] & ((2 << top) - 1)) >> bot)
            return True
        if t.op in ('SignExt', 'ZeroExt'):
            assert len(t.args) == 2
            extend, child = t.args
            if not evaluate(child):
                return False
            x0bits, x0 = V[child.cache_key]
            assert extend >= 0
            if t.op == 'SignExt':
                if x0 >= (1 << (x0bits - 1)):
                    # negative: re-encode in two's complement at the wider width
                    x0 -= 1 << x0bits
                    x0 += 1 << (x0bits + extend)
            V[t.cache_key] = x0bits + extend, x0
            return True

        # remaining operations take only expression arguments
        if not all(hasattr(a, 'cache_key') for a in t.args):
            print('values: unsupported operation %s, falling back to Z3' % t.op)
            return False
        for a in t.args:
            if not evaluate(a):
                return False
        x = [V[a.cache_key] for a in t.args]

        if t.op == 'Concat':
            y = 0
            ybits = 0
            for xbitsi, xi in x:
                y <<= xbitsi
                y += xi
                ybits += xbitsi
            V[t.cache_key] = ybits, y
            return True
        if t.op in ('__eq__', '__ne__'):
            assert len(x) == 2
            assert x[0][0] == x[1][0]
            if t.op == '__eq__':
                V[t.cache_key] = 1, (x[0][1] == x[1][1])
            else:
                V[t.cache_key] = 1, (x[0][1] != x[1][1])
            return True
        if t.op in ('__add__', '__mul__', '__sub__', '__lshift__', 'LShR', '__rshift__', '__and__', '__or__', '__xor__'):
            bits = x[0][0]
            assert all(xi[0] == bits for xi in x)
            modulus = 2**bits
            if t.op == '__add__':
                reduction = lambda a, b: (a + b) % modulus
            elif t.op == '__mul__':
                reduction = lambda a, b: (a * b) % modulus
            elif t.op == '__sub__':
                reduction = lambda a, b: (a - b) % modulus
            elif t.op == '__lshift__':
                # shifting by >= bits leaves 0; don't build a huge intermediate
                reduction = lambda a, b: (a << b) % modulus if b < bits else 0
            elif t.op == 'LShR':
                reduction = lambda a, b: (a >> b) % modulus
            elif t.op == '__rshift__':
                def reduction(a, b):
                    # arithmetic shift: reinterpret both operands as signed
                    flip = 2**(bits - 1)
                    asigned = (a ^ flip) - flip
                    bsigned = (b ^ flip) - flip
                    assert 0 <= bsigned
                    assert bsigned < bits
                    usigned = asigned >> bsigned
                    return (usigned + flip) ^ flip
            elif t.op == '__and__':
                reduction = lambda a, b: a & b
            elif t.op == '__or__':
                reduction = lambda a, b: a | b
            else:  # '__xor__'
                reduction = lambda a, b: a ^ b
            V[t.cache_key] = bits, functools.reduce(reduction, (xi[1] for xi in x))
            return True
        if t.op == '__invert__':
            assert len(x) == 1
            bits = x[0][0]
            V[t.cache_key] = bits, (1 << bits) - 1 - x[0][1]
            return True
        if t.op == 'Not':
            assert len(x) == 1
            assert all(xi[0] == 1 for xi in x)
            V[t.cache_key] = 1, 1 - x[0][1]
            return True
        if t.op in ('And', 'Or'):
            assert all(xi[0] == 1 for xi in x)
            if t.op == 'And':
                reduction = lambda a, b: a * b
            else:
                reduction = lambda a, b: a + b - a * b
            V[t.cache_key] = 1, functools.reduce(reduction, (xi[1] for xi in x))
            return True
        if t.op == 'If':
            assert len(x) == 3
            assert x[0][0] == 1
            if x[0][1]:
                V[t.cache_key] = x[1]
            else:
                V[t.cache_key] = x[2]
            return True
        if t.op in comparisons:
            assert len(x) == 2
            bits = x[0][0]
            assert bits == x[1][0]
            x0, x1 = x[0][1], x[1][1]
            if t.op in signedcomparisons:
                # flipping the sign bit maps signed order onto unsigned order
                flip = 2**(bits - 1)
                x0, x1 = x0 ^ flip, x1 ^ flip
            V[t.cache_key] = 1, comparisons[t.op](x0, x1)
            return True

        # XXX: add support for more
        print('values: unsupported operation %s, falling back to Z3' % t.op)
        return False

    # XXX: also add more validation for all of the above

    for t in terms:
        if not evaluate(t):
            return None
    return V


def unroll_print(outputs, unrolled, f):
    """Write the unrolled expressions to f as straight-line assignments.

    Each distinct subterm gets one 'vN = ...' line (operands before their
    users); then each output entry is written as 'out_<symbol>_<e> = vN'.
    """
    walked = {}

    def walk(t):
        if t in walked:
            return walked[t]
        if t.op == 'BoolV':
            walknext = len(walked) + 1
            f.write('v%d = bool(%d)\n' % (walknext, t.args[0]))
        elif t.op == 'BVV':
            walknext = len(walked) + 1
            f.write('v%d = constant(%d,%d)\n' % (walknext, t.size(), t.args[0]))
        elif t.op == 'BVS':
            walknext = len(walked) + 1
            f.write('v%d = %s\n' % (walknext, t.args[0]))
        elif t.op == 'Extract':
            assert len(t.args) == 3
            arg = 'v%d' % walk(t.args[2])
            walknext = len(walked) + 1
            f.write('v%d = Extract(%s,%d,%d)\n' % (walknext, arg, t.args[0], t.args[1]))
        elif t.op in ['SignExt', 'ZeroExt']:
            assert len(t.args) == 2
            arg = 'v%d' % walk(t.args[1])
            walknext = len(walked) + 1
            f.write('v%d = %s(%s,%d)\n' % (walknext, t.op, arg, t.args[0]))
        else:
            args = ['v%d' % walk(a) for a in t.args]
            walknext = len(walked) + 1
            f.write('v%d = %s(%s)\n' % (walknext, t.op, ','.join(args)))
        walked[t] = walknext
        return walknext

    for x in unrolled:
        walk(x)

    unrolledpos = 0
    for csymbol, bitsperentry, entries in outputs:
        for e in range(entries):
            varname = 'out_%s_%d' % (csymbol, e)
            f.write('%s = v%s\n' % (varname, walk(unrolled[unrolledpos])))
            unrolledpos += 1


def unroll_inputvars(inputs):
    """Return [(name, BVS)] for every input entry, named in_<symbol>_<e>."""
    result = []
    for csymbol, bitsperentry, entries in inputs:
        for e in range(entries):
            varname = 'in_%s_%d' % (csymbol, e)
            variable = claripy.BVS(varname, bitsperentry, explicit_name=True)
            result += [(varname, variable)]
    return result


# XXX: probably better to merge into unroll()
def unroll_worker(binary, inputs, outputs):
    """Symbolically execute binary with angr and return its outputs as expressions.

    Each input entry is stored into static_<symbol> as a symbolic variable
    in_<symbol>_<e>; after execution each output entry is read back from
    static_<symbol>.  With several final states, each output becomes an
    If-chain over the states' path constraints.

    Returns (numexits, ispartition, results):
    (-1, False, simgr.errored) if any state errored; otherwise numexits is
    the number of deadended states, ispartition says whether exactly one
    state's path constraint holds in every assignment, and results holds
    one expression per output entry.
    """
    results = []

    proj = angr.Project(binary, exclude_sim_procedures_list=avoidsimprocedures)
    state = proj.factory.full_init_state()
    state.options |= {angr.options.LAZY_SOLVES}
    state.options |= {angr.options.SYMBOL_FILL_UNCONSTRAINED_MEMORY}
    state.options |= {angr.options.SYMBOL_FILL_UNCONSTRAINED_REGISTERS}
    state.options -= {angr.options.SIMPLIFY_EXPRS}
    state.options -= {angr.options.SIMPLIFY_REGISTER_WRITES}
    state.options -= {angr.options.SIMPLIFY_MEMORY_WRITES}
    state.options -= {angr.options.SIMPLIFY_REGISTER_READS}
    state.options -= {angr.options.SIMPLIFY_MEMORY_READS}

    for csymbol, bitsperentry, entries in inputs:
        xaddr = proj.loader.find_symbol('static_%s' % csymbol).rebased_addr
        for e in range(entries):
            varname = 'in_%s_%d' % (csymbol, e)
            variable = claripy.BVS(varname, bitsperentry, explicit_name=True)
            if bitsperentry == 8:
                state.mem[xaddr + e].char = variable
            elif bitsperentry == 16:
                state.mem[xaddr + 2 * e].short = variable
            elif bitsperentry == 32:
                state.mem[xaddr + 4 * e].int = variable
            elif bitsperentry == 64:
                state.mem[xaddr + 8 * e].long = variable

    simgr = proj.factory.simgr(state)
    simgr.run()

    if len(simgr.errored) > 0:
        return -1, False, simgr.errored

    exits = simgr.deadended
    assert len(exits) > 0

    # cannot be safe if there are multiple exits
    # for equivalence tests we'll merge exits below

    mergedconstraints = []

    for epos, state in enumerate(exits):
        mergedconstraint = state.solver.true
        for c in state.solver.constraints:
            mergedconstraint = state.solver.And(mergedconstraint, c)
        mergedconstraints += [mergedconstraint]

        resultpos = 0
        for csymbol, bitsperentry, entries in outputs:
            xaddr = proj.loader.find_symbol('static_%s' % csymbol).rebased_addr
            for e in range(entries):
                if bitsperentry == 8:
                    xi = state.mem[xaddr + e].char.resolved
                elif bitsperentry == 16:
                    xi = state.mem[xaddr + 2 * e].short.resolved
                elif bitsperentry == 32:
                    xi = state.mem[xaddr + 4 * e].int.resolved
                elif bitsperentry == 64:
                    xi = state.mem[xaddr + 8 * e].long.resolved
                if epos == 0:
                    assert len(results) == resultpos
                    results += [xi]
                else:
                    results[resultpos] = state.solver.If(mergedconstraint, xi, results[resultpos])
                resultpos += 1
        assert resultpos == len(results)

    assert len(mergedconstraints) == len(exits)

    ispartition = True
    # are mergedconstraints a partition of all universes?
    # i.e.: in each universe, exactly one of the constraints is satisfied?

    s = claripy.Solver()
    for c in mergedconstraints:
        s.add(claripy.Not(c))
    if s.satisfiable():
        ispartition = False

    for a in range(len(exits)):
        for b in range(a):
            s = claripy.Solver()
            s.add(mergedconstraints[a])
            s.add(mergedconstraints[b])
            if s.satisfiable():
                ispartition = False

    return len(exits), ispartition, results


def unroll(o, p, i, compiler):
    """Unroll one build with angr and check it against the recorded executions.

    Writes analysis/unrolled and notes split exits, extra variables and
    unused inputs, then checks that the unrolled expressions reproduce the
    outputs recorded by execute() on every test input.

    Returns (o, p, i, compiler, unrolled), or False in place of unrolled if
    unrolling failed or disagreed with execution (a note is recorded).
    """
    builddir = _builddir(p, i, compiler)
    inputs, outputs = op_api[o, p][:2]

    unrolltime = -cputime()

    try:
        numexits, ispartition, unrolled = unroll_worker('%s/analysis-angr' % builddir, inputs, outputs)
    except Exception:
        # an angr/claripy failure should cost this build only, not the whole run
        note(builddir, 'warning-unrollerror', traceback.format_exc())
        return o, p, i, compiler, False

    if numexits < 1:
        note(builddir, 'warning-unrollerror', unrolled)
        return o, p, i, compiler, False
    if not ispartition:
        note(builddir, 'warning-unrollnotpartition')
        return o, p, i, compiler, False
    if numexits > 1:
        note(builddir, 'unsafe-unrollsplit-%d' % numexits)

    with open('%s/analysis/unrolled' % builddir, 'w') as f:
        unroll_print(outputs, unrolled, f)

    okvars = set(vname for vname, v in unroll_inputvars(inputs))
    usedvars = set(v for x in unrolled for v in x.variables)
    if not usedvars.issubset(okvars):
        note(builddir, 'warning-unrollmem')
    if not okvars.issubset(usedvars):
        note(builddir, 'warning-unusedinputs')

    for x, y in zip(op_x[o, p], opic_y[o, p, i, compiler]):
        # cpu gave us outputs y given inputs x
        # does this match unrolled?

        replacements = {}
        xpos = 0
        for csymbol, bitsperentry, entries in inputs:
            for e in range(entries):
                varname = 'in_%s_%d' % (csymbol, e)
                variable = claripy.BVS(varname, bitsperentry, explicit_name=True)
                replacements[variable.cache_key] = claripy.BVV(x[xpos], bitsperentry)
                xpos += 1
        assert xpos == len(x)

        V = None
        try:
            V = values(unrolled, replacements)
        except AssertionError:
            note(builddir, 'warning-valuesfailed', traceback.format_exc())
            # proceed with z3 fallback below

        if V is not None:
            # V maps each term to a (bits, value) pair
            unrolledy = [V[u.cache_key][1] for u in unrolled]
            mismatch = any(yi != ui for yi, ui in zip(y, unrolledy))
        else:
            # fall back on Z3: can any substituted output differ from y?
            s = claripy.Solver()
            substituted = [u.replace_dict(replacements) for u in unrolled]
            different = claripy.false
            for yi, ui in zip(y, substituted):
                different = claripy.Or(different, ui != yi)
            s.add(different)
            mismatch = s.satisfiable()
            if mismatch:
                # no constraints added since satisfiable(), so these values
                # are consistent with the mismatching model
                unrolledy = [s.eval(ui, 1)[0] for ui in substituted]

        if mismatch:
            notestr = input_example_str(inputs, x)
            pos = 0
            for csymbol, bitsperentry, entries in outputs:
                for e in range(entries):
                    varname = 'out_%s_%d' % (csymbol, e)
                    notestr += 'executed_%s = %s\n' % (varname, y[pos])
                    notestr += 'unrolled_%s = %s\n' % (varname, unrolledy[pos])
                    pos += 1
            note(builddir, 'warning-unrollmismatch', notestr)
            return o, p, i, compiler, False

    unrolltime += cputime()
    notetime(builddir, 'unroll', unrolltime)

    return o, p, i, compiler, unrolled


def wanttounroll():
    """Yield (o, p, i, compiler) for every build that executed successfully."""
    for o, p in primitives:
        for i, compiler in op_compiled[o, p]:
            if (o, p, i, compiler) in opic_y:
                yield o, p, i, compiler


opic_unrolled = {}

with _pool() as pool:
    for o, p, i, compiler, unrolled in pool.starmap(unroll, wanttounroll()):
        if unrolled is False:
            continue
        opic_unrolled[o, p, i, compiler] = unrolled

print('----- compareunrolled (can take tons of time)')


def compareunrolled(o, p, i, compiler, source, sourcecompiler):
    """Compare build (i, compiler) against build (source, sourcecompiler).

    First compares the recorded outputs on each test input, noting
    unsafe-randomtest-<pos>-differentfrom-... for each difference.  Then
    asks Z3 whether the unrolled outputs can differ: notes
    unsafe-differentfrom-... with an example if so, equals-... if not, or
    warning-z3failed if the solver raised ClaripyZ3Error.
    """
    builddir = _builddir(p, i, compiler)
    sourcecompilerword = _compilerword(sourcecompiler)
    inputs, outputs = op_api[o, p][:2]

    for pos, (x, y, z) in enumerate(zip(op_x[o, p], opic_y[o, p, i, compiler], opic_y[o, p, source, sourcecompiler])):
        if y != z:
            xstr = input_example_str(inputs, x)
            note(builddir, 'unsafe-randomtest-%d-differentfrom-%s-%s' % (pos, source, sourcecompilerword), xstr)
            # could return at this point to save time
            # but to help validate symbolic testing we also want to see symbolic testing fail

    equivtime = -cputime()

    u1 = opic_unrolled[o, p, source, sourcecompiler]
    u2 = opic_unrolled[o, p, i, compiler]
    assert len(u1) == len(u2)

    # XXX: allow other equivalence-testing techniques

    s = claripy.Solver()
    different = claripy.false
    for u1j, u2j in zip(u1, u2):
        different = claripy.Or(different, u1j != u2j)
    s.add(different)
    try:
        mismatch = s.satisfiable()
    except claripy.errors.ClaripyZ3Error:
        # avoid crashing on the sort of bug fixed in https://github.com/angr/angr/pull/2887
        note(builddir, 'warning-z3failed', traceback.format_exc())
        return

    if mismatch:
        # angr documentation says:
        # "If you don't add any constraints between two queries, the results will be consistent with each other."
        example = ''
        for vname, v in unroll_inputvars(inputs):
            example += '%s = %s\n' % (vname, s.eval(v, 1)[0])
        unrolledpos = 0
        for csymbol, bitsperentry, entries in outputs:
            for e in range(entries):
                varname = 'out_%s_%d' % (csymbol, e)
                example += 'source_%s = %s\n' % (varname, s.eval(u1[unrolledpos], 1)[0])
                example += 'target_%s = %s\n' % (varname, s.eval(u2[unrolledpos], 1)[0])
                unrolledpos += 1
        note(builddir, 'unsafe-differentfrom-%s-%s' % (source, sourcecompilerword), example)
    else:
        note(builddir, 'equals-%s-%s' % (source, sourcecompilerword))

    equivtime += cputime()
    notetime(builddir, 'equiv', equivtime)


def wanttocompareunrolled():
    """Yield (o, p, i, compiler, source, sourcecompiler) pairs to compare.

    Each unrolled build is compared against 'ref' built with the same
    compiler; other 'ref' builds are compared against 'ref' built with the
    first compiler in compilerlist.
    """
    for o, p in primitives:
        for i, compiler in op_compiled[o, p]:
            source = 'ref'  # XXX: allow each implementation to choose source
            if i == 'ref':
                sourcecompiler = compilerlist[0]  # XXX: maybe also allow choice
            else:
                sourcecompiler = compiler
            if (o, p, i, compiler) not in opic_unrolled:
                continue
            if (o, p, source, sourcecompiler) not in opic_unrolled:
                continue
            # XXX: could also do self-tests
            if (o, p, source, sourcecompiler) == (o, p, i, compiler):
                continue
            yield o, p, i, compiler, source, sourcecompiler


with _pool() as pool:
    list(pool.starmap(compareunrolled, wanttocompareunrolled()))