-rwxr-xr-x 41356 saferewrite-20211125/analyze
#!/usr/bin/env python3
"""saferewrite analysis driver.

For every primitive under src/ and every implementation of it, this script
builds test harnesses with each compiler in compilerlist, runs random tests,
runs valgrind, unrolls the binary with angr, and uses Z3 (via claripy) to
compare each implementation against the 'ref' implementation.  Results are
written as files under build/<primitive>/<implementation>/<compiler>/analysis.
"""

compilerlist = (
    'clang -O1 -fwrapv -march=native',
    'gcc -O3 -march=native -mtune=native',
)

typebits = {
    'int8': 8,
    'int16': 16,
    'int32': 32,
    'int64': 64,
    'uint8': 8,
    'uint16': 16,
    'uint32': 32,
    'uint64': 64,
}

rusttype = {
    'int8': 'i8',
    'int16': 'i16',
    'int32': 'i32',
    'int64': 'i64',
    'uint8': 'u8',
    'uint16': 'u16',
    'uint32': 'u32',
    'uint64': 'u64',
}

import sys
import os
import shutil
import subprocess
import angr
import claripy
import multiprocessing
import random
import traceback
import functools

# ----- performance-related configuration...

numrandomtests = 16
# test 16 random inputs to each function
# (first two are all-0, all-1)
# of course more time on testing and fuzzing would be useful

maxsplit = 100
# max number of universes within an angr run
# XXX: allow per-primitive/per-implementation configuration

satvalidation1 = False
# True would mean: even if random tests fail, invoke general sat/unsat mechanism
# this is for extra validation of that mechanism

claripy.simplifications.simpleton._simplifiers = {
    'Reverse': claripy.simplifications.simpleton.bv_reverse_simplifier,
    'Extract': claripy.simplifications.simpleton.extract_simplifier,
    'Concat': claripy.simplifications.simpleton.concat_simplifier,
    'ZeroExt': claripy.simplifications.simpleton.zeroext_simplifier,
    'SignExt': claripy.simplifications.simpleton.signext_simplifier,
}
# this is somewhat ad-hoc monkey-patching
# underlying problem: flatten simplifiers can easily blow up

sys.setrecursionlimit(1000000)

# ----- real work begins here

z3timeout = 2**32-1
# XXX: by default claripy satisfiable() has 5-minute timeout
# and indicates timeout by returning False
# current Z3 documentation says:
# timeout (in milliseconds) (UINT_MAX and 0 mean no timeout)

try:
    os_cores = len(os.sched_getaffinity(0))
except AttributeError:
    os_cores = multiprocessing.cpu_count()
os_cores = os.getenv('CORES',default=os_cores)
os_cores = int(os_cores)
if os_cores < 1: os_cores = 1

import resource

def cputime():
    """Return user CPU seconds consumed so far by this process plus its waited-for children."""
    return resource.getrusage(resource.RUSAGE_SELF).ru_utime + resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime

def notetime(builddir,what,time):
    """Print a timing line and append it to <builddir>/analysis/seconds."""
    print('%s seconds %s %.6f' % (builddir,what,time))
    sys.stdout.flush()
    with open('%s/analysis/seconds' % builddir,'a') as f:
        f.write('%s %.6f\n' % (what,time))

def note(builddir,conclusion,contents=None):
    """Print a conclusion and record it as file <builddir>/analysis/<conclusion>.

    The file is created (truncated) even when contents is None; otherwise
    str(contents) is written into it.
    """
    print('%s %s' % (builddir,conclusion))
    sys.stdout.flush()
    with open('%s/analysis/%s' % (builddir,conclusion),'w') as f:
        if contents is not None:
            f.write(str(contents))

startdir = os.getcwd()
# startdir is later interpolated into a command string that is split on whitespace
assert all(x in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_./' for x in startdir)

shutil.rmtree('build',ignore_errors=True)
os.makedirs('build')

# ----- find primitives: src/<p> directories that are not sticky and have an api file

primitives = []

for o in 'src',:
    o = o.strip()
    if o == '': continue
    if not os.path.isdir(o): continue
    if os.stat('%s' % o).st_mode & 0o1000 == 0o1000:
        print('%s sticky, skipping' % o)
        sys.stdout.flush()
        continue
    for p in sorted(os.listdir(o)):
        if not os.path.isdir('%s/%s' % (o,p)): continue
        if os.stat('%s/%s' % (o,p)).st_mode & 0o1000 == 0o1000:
            print('%s/%s sticky, skipping' % (o,p))
            sys.stdout.flush()
            continue
        if not os.path.exists('%s/%s/api' % (o,p)):
            print('%s/%s/api nonexistent, skipping' % (o,p))
            sys.stdout.flush()
            continue
        primitives += [(o,p)]

# ----- parse each primitive's api file
# recognized lines: "call NAME", "return TYPE SYMBOL", "in|out|inout TYPE SYMBOL [ENTRIES]"

op_api = {}

for o,p in primitives:
    inputs = []
    outputs = []
    funargs = []
    funargtypes = []
    rustargs = []
    rustargtypes = []
    crustargs = []
    crustargtypes = []
    funname = None
    funret = None
    funrettype = 'void'

    with open('%s/%s/api' % (o,p)) as f:
        for line in f:
            line = line.split()
            if len(line) == 0: continue
            if line[0] == 'call':
                funname = line[1]
                assert all(x in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_' for x in funname)
            if line[0] == 'return':
                symboltype = line[1]
                csymbol = line[2]
                assert all(c in 'abcdefghijklmnopqrstuvwxyz' for c in csymbol)
                entries = 1
                outputs += [(csymbol,symboltype,entries)]
                funret = 'alloc_%s'%csymbol
                funrettype = 'uint%d_t'%typebits[symboltype]
            if line[0] in ('in','out','inout'):
                symboltype = line[1]
                csymbol = line[2]
                assert all(c in 'abcdefghijklmnopqrstuvwxyz' for c in csymbol)
                if len(line) == 3:
                    # no entry count: scalar passed by value
                    pointer = False
                    entries = 1
                else:
                    pointer = True
                    entries = int(line[3])
                if line[0] in ('in','inout'):
                    inputs += [(csymbol,symboltype,entries)]
                if line[0] in ('out','inout'):
                    outputs += [(csymbol,symboltype,entries)]
                if pointer:
                    funargs += ['alloc_%s'%csymbol]
                    funargtypes += ['uint%d_t *' % typebits[symboltype]]
                    crustargs += ['c_%s'%csymbol]
                    rustargs += ['rust_%s'%csymbol]
                    if line[0] == 'in':
                        crustargtypes += ['*const %s' % rusttype[symboltype]]
                    else:
                        crustargtypes += ['*mut %s' % rusttype[symboltype]]
                else:
                    funargs += ['*alloc_%s'%csymbol]
                    funargtypes += ['uint%d_t' % typebits[symboltype]]
                    crustargs += ['c_%s'%csymbol]
                    rustargs += ['rust_%s'%csymbol]
                    crustargtypes += ['const %s' % rusttype[symboltype]]
                    # XXX: support constant inputs

    op_api[o,p] = inputs,outputs,funargs,funargtypes,crustargs,crustargtypes,rustargs,funname,funret,funrettype

def input_example_str(inputs,x):
    """Format the flat list x of input values as 'in_<symbol>_<index> = <value>' lines.

    x must contain exactly one value per entry of inputs, in order.
    """
    xstr = ''
    xpos = 0
    for csymbol,symboltype,entries in inputs:
        for e in range(entries):
            varname = 'in_%s_%d'%(csymbol,e)
            xstr += '%s = %d\n' % (varname,x[xpos])
            xpos += 1
    assert xpos == len(x)
    return xstr

def output_example_str(outputs,y):
    """Format the flat list y of output values as 'out_<symbol>_<index> = <value>' lines.

    y must contain exactly one value per entry of outputs, in order.
    """
    ystr = ''
    ypos = 0
    for csymbol,symboltype,entries in outputs:
        for e in range(entries):
            varname = 'out_%s_%d'%(csymbol,e)
            ystr += '%s = %d\n' % (varname,y[ypos])
            ypos += 1
    assert ypos == len(y)
    return ystr

# files this script creates inside each build directory
reservedfilenames = (
    'library.so.1',
    'analysis',
    'analysis-execute',
    'analysis-execute.c',
    'analysis-valgrind',
    'analysis-valgrind.c',
    'analysis-angr',
    'analysis-angr.c',
)

# ----- find implementations of each primitive

opimplementations = {}

for o,p in primitives:
    opimplementations[o,p] = []
    for i in sorted(os.listdir('%s/%s' % (o,p))):
        implementationdir = '%s/%s/%s' % (o,p,i)
        if not os.path.isdir(implementationdir): continue
        if os.stat(implementationdir).st_mode & 0o1000 == 0o1000:
            print('%s/%s/%s sticky, skipping' % (o,p,i))
            continue
        files = sorted(os.listdir(implementationdir))
        # one bad file rejects the implementation, so ok must not be reset per file
        ok = True
        for f in files:
            if f in reservedfilenames:
                print('%s/%s/%s/%s reserved filename' % (o,p,i,f))
                ok = False
            if any(fi not in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.' for fi in f):
                print('%s/%s/%s/%s prohibited character' % (o,p,i,f))
                ok = False
        if not ok: continue
        opimplementations[o,p] += [i]
        # parent of the per-compiler build directories that compile() creates
        os.makedirs('build/%s/%s' % (p,i),exist_ok=True)

def compile(o,p,i,compiler):
    """Build the three analysis binaries for one implementation and compiler.

    Copies the implementation into its build directory, generates
    crypto_(u)intN.h headers and analysis-{execute,valgrind,angr}.c harnesses,
    then builds either through cargo (when Cargo.toml is present) or by
    compiling each .c/.s/.S file and linking with gcc.

    Returns (o,p,i,compiler,ok); failures are recorded with note() and
    reported as ok == False.
    """
    compilerword = compiler.replace(' ','_').replace('=','_')
    implementationdir = '%s/%s/%s' % (o,p,i)
    builddir = 'build/%s/%s/%s' % (p,i,compilerword)
    inputs,outputs,funargs,funargtypes,crustargs,crustargtypes,rustargs,funname,funret,funrettype = op_api[o,p]

    files = sorted(os.listdir(implementationdir))
    rust = 'Cargo.toml' in files
    rustlocked = 'Cargo.lock' in files
    cfiles = [x for x in files if x.endswith('.c')]
    sfiles = [x for x in files if x.endswith('.s') or x.endswith('.S')]
    files = cfiles + sfiles

    shutil.copytree(implementationdir,builddir)
    os.makedirs('%s/analysis' % builddir)

    for bits in 8,16,32,64:
        with open('%s/crypto_int%d.h' % (builddir,bits),'w') as f:
            f.write('#include <inttypes.h>\n')
            f.write('#define crypto_int%d int%d_t' % (bits,bits))
        with open('%s/crypto_uint%d.h' % (builddir,bits),'w') as f:
            f.write('#include <inttypes.h>\n')
            f.write('#define crypto_uint%d uint%d_t' % (bits,bits))

    for analysis in 'execute','valgrind','angr':
        with open('%s/analysis-%s.c' % (builddir,analysis),'w') as f:
            f.write('#include <stdio.h>\n')
            f.write('#include <stdlib.h>\n')
            f.write('#include <string.h>\n')
            f.write('#include <inttypes.h>\n')
            f.write('\n')

            # function declaration
            f.write('extern ')
            if funrettype is not None: f.write('%s ' % funrettype)
            f.write('%s(%s);\n' % (funname,','.join(funargtypes)))
            f.write('\n')

            # static_* arrays are where angr injects inputs and reads outputs
            for csymbol,symboltype,entries in inputs:
                f.write('uint%d_t static_%s[%d];\n' % (typebits[symboltype],csymbol,entries))
            for csymbol,symboltype,entries in outputs:
                if (csymbol,symboltype,entries) not in inputs:
                    f.write('uint%d_t static_%s[%d];\n' % (typebits[symboltype],csymbol,entries))
            f.write('\n')

            f.write('int main(int argc,char **argv)\n')
            f.write('{\n')
            for csymbol,symboltype,entries in inputs:
                f.write(' uint%d_t *alloc_%s = malloc(%d);\n' % (typebits[symboltype],csymbol,entries*typebits[symboltype]//8))
            for csymbol,symboltype,entries in outputs:
                if (csymbol,symboltype,entries) not in inputs:
                    f.write(' uint%d_t *alloc_%s = malloc(%d);\n' % (typebits[symboltype],csymbol,entries*typebits[symboltype]//8))
            f.write('\n')
            # XXX: resource limits

            if analysis == 'execute':
                for csymbol,symboltype,entries in inputs:
                    f.write(' for (long long i = 0;i < %d;++i) {\n' % entries)
                    f.write(' unsigned long long x;\n')
                    f.write(' if (scanf("%llu",&x) != 1) abort();\n')
                    f.write(' static_%s[i] = x;\n' % csymbol)
                    f.write(' }\n')
                f.write('\n')

            # the valgrind harness deliberately leaves the malloc'ed inputs uninitialized
            if analysis in ('execute','angr'):
                for csymbol,symboltype,entries in inputs:
                    f.write(' for (long long i = 0;i < %d;++i)\n' % entries)
                    f.write(' alloc_%s[i] = static_%s[i];\n' % (csymbol,csymbol))
                f.write('\n')

            f.write(' ')
            if funret is not None: f.write('%s[0] = ' % funret)
            f.write('%s(%s);\n' % (funname,','.join(funargs)))
            f.write('\n')

            if analysis in ('execute','angr'):
                for csymbol,symboltype,entries in outputs:
                    f.write(' for (long long i = 0;i < %d;++i)\n' % entries)
                    f.write(' static_%s[i] = alloc_%s[i];\n' % (csymbol,csymbol))
                f.write('\n')

            if analysis == 'execute':
                for csymbol,symboltype,entries in outputs:
                    f.write(' for (long long i = 0;i < %d;++i) {\n' % entries)
                    f.write(' unsigned long long x = static_%s[i];\n' % csymbol)
                    f.write(' printf("%llu\\n",x);\n')
                    f.write(' }\n')
                f.write(' fflush(stdout);\n')
                f.write('\n')

            f.write(' return 0;\n')
            f.write('}\n')

    # ----- rust

    if rust:
        os.makedirs('%s/src/bin'%builddir,exist_ok=True)
        # appended to the implementation's own Cargo.toml; the leading blank
        # line protects against a missing trailing newline in that file
        with open('%s/Cargo.toml'%builddir,'a') as f:
            f.write("""
[build-dependencies]
cc = { git = "https://github.com/alexcrichton/cc-rs", version = "1.0.67", rev = "3283434ee41f2a1be6c78fe6d6a3ec8eb92b1833" }

[[bin]]
name = "analysis-execute"

[[bin]]
name = "analysis-valgrind"

[[bin]]
name = "analysis-angr"

[profile.release]
panic = "abort"
""")

        with open('%s/build.rs'%builddir,'w') as f:
            f.write("""\
use std::env;

fn main() {
    let out_dir = env::var("OUT_DIR").unwrap();
    println!("cargo:rustc-link-search=native={}",out_dir);
    cc::Build::new()
        .cargo_metadata(false)
        .file("analysis-execute.c")
        .compile("analysis-execute");
    println!("cargo:rerun-if-changed=analysis-execute.c");
    cc::Build::new()
        .cargo_metadata(false)
        .file("analysis-valgrind.c")
        .compile("analysis-valgrind");
    println!("cargo:rerun-if-changed=analysis-valgrind.c");
    cc::Build::new()
        .cargo_metadata(false)
        .file("analysis-angr.c")
        .compile("analysis-angr");
    println!("cargo:rerun-if-changed=analysis-angr.c");
}
""")

        for analysis in 'execute','valgrind','angr':
            with open('%s/src/bin/analysis-%s.rs'%(builddir,analysis),'w') as f:
                f.write('#![no_std]\n')
                f.write('#![no_main]\n')
                f.write('extern crate %s;\n'%funname)
                f.write('\n')
                f.write('use core::convert::TryInto;\n')
                f.write('use core::slice;\n')
                f.write('\n')
                f.write('#[link(name="analysis-%s")]\n'%analysis)
                f.write('extern {\n')
                f.write('#[allow(dead_code)]\n')
                f.write(' fn main() -> i32;\n')
                f.write('}\n')
                f.write('\n')
                f.write('#[no_mangle]\n')
                f.write('pub unsafe extern "C"\n')
                f.write('fn %s(\n'%funname)
                for arg,argtype in zip(crustargs,crustargtypes):
                    f.write(' %s:%s,\n' % (arg,argtype))
                f.write(') -> i32 {\n')
                for csymbol,symboltype,entries in outputs:
                    f.write(' let rust_%s = slice::from_raw_parts_mut(c_%s,%d);\n' % (csymbol,csymbol,entries))
                    f.write(' let rust_%s:&mut[%s;%d] = rust_%s.try_into().unwrap();\n' % (csymbol,rusttype[symboltype],entries,csymbol))
                for csymbol,symboltype,entries in inputs:
                    if (csymbol,symboltype,entries) not in outputs:
                        f.write(' let rust_%s = slice::from_raw_parts(c_%s,%d);\n' % (csymbol,csymbol,entries))
                f.write(' %s::%s(%s);\n' % (funname,funname,','.join(rustargs)))
                # XXX: return value
                f.write(' 0\n')
                f.write('}\n')

        cargotime = -cputime()
        # copy, so the settings below do not leak into this worker's own
        # environment and from there into later, unrelated subprocesses
        env = dict(os.environ)
        env['CC'] = compiler.split()[0]
        env['CFLAGS'] = ' '.join(compiler.split()[1:])
        env['RUSTFLAGS'] = '-C target_cpu=native -C link-args=-no-pie'
        # XXX: let user control this?
        if rustlocked:
            command = 'cargo build -q --locked --release --bin analysis-execute --bin analysis-valgrind --bin analysis-angr'
        else:
            command = 'cargo build -q --release --bin analysis-execute --bin analysis-valgrind --bin analysis-angr'
        try:
            proc = subprocess.Popen(command.split(),cwd=builddir,env=env,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True)
            out,err = proc.communicate()
        except OSError:
            note(builddir,'warning-cargofailed',traceback.format_exc())
            return o,p,i,compiler,False
        assert not err
        if out != '':
            note(builddir,'warning-cargooutput',out)
        if proc.returncode:
            note(builddir,'warning-cargofailed','exit code %s' % proc.returncode)
            return o,p,i,compiler,False
        for analysis in 'execute','valgrind','angr':
            shutil.copy('%s/target/release/analysis-%s'%(builddir,analysis),'%s/analysis-%s'%(builddir,analysis))
        cargotime += cputime()
        notetime(builddir,'cargo',cargotime)
        return o,p,i,compiler,True

    # ----- compile

    compiletime = -cputime()
    objfiles = []
    for f in files+['analysis-execute.c','analysis-valgrind.c','analysis-angr.c']:
        command = '%s -Wall -fPIC -DCRYPTO_NAMESPACE(x)=x -c %s' % (compiler,f)
        try:
            proc = subprocess.Popen(command.split(),cwd=builddir,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True)
            out,err = proc.communicate()
        except OSError:
            note(builddir,'warning-compilefailed-%s'%f,traceback.format_exc())
            return o,p,i,compiler,False
        assert not err
        if out != '':
            note(builddir,'warning-compileoutput-%s'%f,out)
        if proc.returncode:
            note(builddir,'warning-compilefailed-%s'%f,'exit code %s' % proc.returncode)
            return o,p,i,compiler,False
        # only implementation objects are shared by all three harnesses
        if f in files:
            objfiles += ['.'.join(f.split('.')[:-1]+['o'])]
    compiletime += cputime()
    notetime(builddir,'compile',compiletime)

    # ----- link into executable

    linktime = -cputime()
    for analysis in 'execute','valgrind','angr':
        static = True
        if static:
            command = 'gcc -no-pie -o analysis-%s analysis-%s.o' % (analysis,analysis)
            command = command.split()
            command += objfiles
        else:
            # build a shared library first, then link the harness against it
            command = 'gcc -shared -Wl,-soname,library.so.1 -o library.so.1'
            command = command.split()
            command += objfiles
            try:
                proc = subprocess.Popen(command,cwd=builddir,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True)
                out,err = proc.communicate()
            except OSError:
                note(builddir,'warning-linkfailed-%s'%analysis,traceback.format_exc())
                return o,p,i,compiler,False
            if out != '':
                note(builddir,'warning-linkoutput-%s'%analysis,out)
            assert not err
            if proc.returncode:
                note(builddir,'warning-linkfailed-%s'%analysis,'exit code %s' % proc.returncode)
                return o,p,i,compiler,False
            shutil.copy('%s/library.so.1' % builddir,'%s/library.so' % builddir)
            command = 'gcc -no-pie -o analysis-%s analysis-%s.o -Wl,-rpath=%s/%s -L. -lrary' % (analysis,analysis,startdir,builddir)
            command = command.split()
        try:
            proc = subprocess.Popen(command,cwd=builddir,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True)
            out,err = proc.communicate()
        except OSError:
            note(builddir,'warning-linkfailed-%s'%analysis,traceback.format_exc())
            return o,p,i,compiler,False
        if out != '':
            note(builddir,'warning-linkoutput-%s'%analysis,out)
        assert not err
        if proc.returncode:
            note(builddir,'warning-linkfailed-%s'%analysis,'exit code %s' % proc.returncode)
            return o,p,i,compiler,False
    linktime += cputime()
    notetime(builddir,'link',linktime)

    return o,p,i,compiler,True

def wanttocompile():
    """Yield (o,p,i,compiler) for every implementation/compiler pair to build."""
    for o,p in primitives:
        for i in opimplementations[o,p]:
            for compiler in compilerlist:
                yield o,p,i,compiler

op_compiled = {}
for o,p in primitives:
    op_compiled[o,p] = []

with multiprocessing.Pool(os_cores) as pool:
    for o,p,i,compiler,ok in pool.starmap(compile,wanttocompile()):
        if not ok: continue
        op_compiled[o,p] += [(i,compiler)]

print('----- execute')

# random test inputs per primitive: all-0, all-1, then uniform random values
op_x = {}
for o,p in primitives:
    inputs,outputs,funargs,funargtypes,crustargs,crustargtypes,rustargs,funname,funret,funrettype = op_api[o,p]
    op_x[o,p] = []
    for execution in range(numrandomtests):
        x = []
        for csymbol,symboltype,entries in inputs:
            for e in range(entries):
                if execution == 0:
                    r = 0
                elif execution == 1:
                    r = 2**typebits[symboltype]-1
                else:
                    r = random.randrange(2**typebits[symboltype])
                x += [r]
        op_x[o,p] += [x]

def execute(o,p,i,compiler):
    """Run ./analysis-execute once per test input in op_x[o,p].

    Each input list is fed on stdin, one decimal number per line; the
    binary's stdout is parsed as one decimal number per output entry.

    Returns (o,p,i,compiler,results) where results is the list of output
    lists, or False if any run failed or printed malformed output.
    """
    compilerword = compiler.replace(' ','_').replace('=','_')
    implementationdir = '%s/%s/%s' % (o,p,i)
    builddir = 'build/%s/%s/%s' % (p,i,compilerword)
    inputs,outputs,funargs,funargtypes,crustargs,crustargtypes,rustargs,funname,funret,funrettype = op_api[o,p]

    executetime = -cputime()
    results = []
    command = ['./analysis-execute']
    for x in op_x[o,p]:
        xstr = ''
        for r in x:
            xstr += '%d\n'%r
        try:
            proc = subprocess.Popen(command,cwd=builddir,stdin=subprocess.PIPE,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True)
            ystr,err = proc.communicate(input=xstr)
        except OSError:
            note(builddir,'warning-executeerror',xstr)
            return o,p,i,compiler,False
        if proc.returncode != 0:
            note(builddir,'warning-executefailed',xstr+'exit code %s' % proc.returncode)
            return o,p,i,compiler,False
        try:
            y = [int(s) for s in ystr.splitlines()]
            ypos = 0
            for csymbol,symboltype,entries in outputs:
                for e in range(entries):
                    if not 0 <= y[ypos] < 2**typebits[symboltype]:
                        raise ValueError('output %d out of range' % ypos)
                    ypos += 1
            if ypos != len(y):
                raise ValueError('unexpected number of outputs')
        except (ValueError,IndexError):
            # y may be unparsed or of the wrong length here, so report the raw output
            note(builddir,'warning-executebadformat',input_example_str(inputs,x)+ystr)
            return o,p,i,compiler,False
        results += [y]
    executetime += cputime()
    notetime(builddir,'execute',executetime)

    return o,p,i,compiler,results

def wanttoexecute():
    """Yield (o,p,i,compiler) for every successfully compiled build."""
    for o,p in primitives:
        for i,compiler in op_compiled[o,p]:
            yield o,p,i,compiler

opic_y = {}
with multiprocessing.Pool(os_cores) as pool:
    for o,p,i,compiler,results in pool.starmap(execute,wanttoexecute()):
        if results is False: continue
        opic_y[o,p,i,compiler] = results

print('----- valgrind (can take some time)')

def valgrind(o,p,i,compiler):
    """Run ./analysis-valgrind under valgrind and note any failure.

    Exit code 99 (valgrind's --error-exitcode) or 'client request' in the
    output is noted as unsafe-valgrindfailure; any other nonzero exit code
    or failure to start is noted as warning-valgrinderror.
    """
    compilerword = compiler.replace(' ','_').replace('=','_')
    implementationdir = '%s/%s/%s' % (o,p,i)
    builddir = 'build/%s/%s/%s' % (p,i,compilerword)

    valgrindtime = -cputime()
    command = ['valgrind','-q','--error-exitcode=99','./analysis-valgrind']
    valgrindstatus = None
    try:
        proc = subprocess.Popen(command,cwd=builddir,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True)
        out,err = proc.communicate()
    except OSError:
        valgrindstatus = 'warning-valgrinderror'
    if valgrindstatus is None:
        assert not err
        if proc.returncode == 99:
            valgrindstatus = 'unsafe-valgrindfailure'
        elif proc.returncode != 0:
            valgrindstatus = 'warning-valgrinderror'
        elif out.find('client request') >= 0:
            valgrindstatus = 'unsafe-valgrindfailure'
    if valgrindstatus is not None:
        note(builddir,valgrindstatus)
    valgrindtime += cputime()
    notetime(builddir,'valgrind',valgrindtime)

def wanttovalgrind():
    """Yield (o,p,i,compiler) for every successfully compiled build."""
    for o,p in primitives:
        for i,compiler in op_compiled[o,p]:
            yield o,p,i,compiler

with multiprocessing.Pool(os_cores) as pool:
    list(pool.starmap(valgrind,wanttovalgrind()))

print('----- unroll (can take tons of time)')
# XXX: could do this in parallel with valgrind
# XXX: unrolled can be huge; pass through disk instead of RAM

def values(terms,replacements):
    """Concretely evaluate claripy terms in pure Python.

    input: replacements mapping cache_key of each variable to a claripy
    constant (its args[0] is used as the integer value)

    Returns (V,W).  V is a dictionary mapping cache_key to pairs (b,i) where
    i is a b-bit value, and includes all terms; V is None if terms use
    variables outside replacements or an unsupported operation.  W is a set
    of warning names ('mul', 'div') for operations that were encountered.

    Raises AssertionError if a term has an unexpected shape.
    """
    V = {}
    W = set() # warnings

    def evaluate(t):
        if t.cache_key in V:
            return True
        if t.op == 'BoolV':
            V[t.cache_key] = 1,t.args[0]
            return True
        if t.op == 'BVV':
            V[t.cache_key] = t.size(),t.args[0]
            return True
        if t.op == 'BVS':
            if t.cache_key not in replacements:
                return False
            V[t.cache_key] = t.size(),replacements[t.cache_key].args[0]
            return True
        if t.op == 'Extract':
            assert len(t.args) == 3
            top = t.args[0]
            bot = t.args[1]
            if not evaluate(t.args[2]):
                return False
            x0 = V[t.args[2].cache_key]
            assert x0[0] > top
            assert top >= bot
            assert bot >= 0
            V[t.cache_key] = top+1-bot,((x0[1] & ((2<<top)-1)) >> bot)
            return True
        if t.op in ('SignExt','ZeroExt'):
            assert len(t.args) == 2
            if not evaluate(t.args[1]):
                return False
            x0bits,x0 = V[t.args[1].cache_key]
            extend = t.args[0]
            assert extend >= 0
            if t.op == 'SignExt':
                # reinterpret as negative, then wrap into the wider width
                if x0 >= (1<<(x0bits-1)):
                    x0 -= 1<<x0bits
                    x0 += 1<<(x0bits+extend)
            V[t.cache_key] = x0bits+extend,x0
            return True

        # all remaining operations take only term arguments
        for a in t.args:
            if not evaluate(a):
                return False
        x = [V[a.cache_key] for a in t.args]

        if t.op == 'Concat':
            y = 0
            ybits = 0
            for xbitsi,xi in x:
                y <<= xbitsi
                y += xi
                ybits += xbitsi
            V[t.cache_key] = ybits,y
            return True
        if t.op in ('__eq__','__ne__'):
            assert len(x) == 2
            assert x[0][0] == x[1][0]
            if t.op == '__eq__':
                V[t.cache_key] = 1,(x[0][1]==x[1][1])
            elif t.op == '__ne__':
                V[t.cache_key] = 1,(x[0][1]!=x[1][1])
            else:
                print('values: internal error %s, falling back to Z3' % t.op)
                return False
            return True
        if t.op in ('__add__','__mul__','SDiv','__floordiv__','SMod','__mod__','__sub__','__lshift__','LShR','__rshift__','__and__','__or__','__xor__'):
            bits = x[0][0]
            assert all(xi[0] == bits for xi in x)
            if t.op == '__and__':
                reduction = (lambda s,t:s&t)
            elif t.op == '__or__':
                reduction = (lambda s,t:s|t)
            elif t.op == '__xor__':
                reduction = (lambda s,t:s^t)
            elif t.op == '__add__':
                reduction = (lambda s,t:(s+t)%(2**bits))
            elif t.op == '__sub__':
                reduction = (lambda s,t:(s-t)%(2**bits))
            elif t.op == '__lshift__':
                # SMT-LIB bvshl: unsigned shift amount; amount >= bits gives 0
                # (also avoids building a gigantic intermediate integer)
                reduction = (lambda s,t:(s<<t)%(2**bits) if t < bits else 0)
            elif t.op == 'LShR':
                # SMT-LIB bvlshr: amount >= bits gives 0
                reduction = (lambda s,t:(s>>t)%(2**bits) if t < bits else 0)
            elif t.op == '__rshift__':
                def reduction(s,t):
                    flip = 2**(bits-1)
                    ssigned = (s ^ flip) - flip
                    # SMT-LIB bvashr: unsigned shift amount;
                    # amount >= bits leaves only copies of the sign bit
                    usigned = ssigned >> min(t,bits)
                    return usigned%(2**bits)
            elif t.op == '__mul__':
                reduction = (lambda s,t:(s*t)%(2**bits))
                W.add('mul')
            elif t.op == '__floordiv__':
                def reduction(s,t):
                    if t == 0: return 0
                    return (s//t)%(2**bits)
                W.add('div')
            elif t.op == '__mod__':
                def reduction(s,t):
                    if t == 0: return s
                    return (s%t)%(2**bits)
                W.add('div')
            elif t.op == 'SDiv':
                def reduction(s,t):
                    if t == 0: return 0
                    flip = 2**(bits-1)
                    ssigned = (s ^ flip) - flip
                    tsigned = (t ^ flip) - flip
                    # sdiv definition in Z3:
                    # - The \c floor of [t1/t2] if \c t2 is different from zero, and [t1*t2 >= 0].
                    # - The \c ceiling of [t1/t2] if \c t2 is different from zero, and [t1*t2 < 0].
                    if ssigned*tsigned >= 0:
                        usigned = ssigned // tsigned
                    else:
                        usigned = -((-ssigned) // tsigned)
                    return usigned%(2**bits)
                W.add('div')
            elif t.op == 'SMod':
                def reduction(s,t):
                    if t == 0: return s
                    flip = 2**(bits-1)
                    ssigned = (s ^ flip) - flip
                    tsigned = (t ^ flip) - flip
                    # srem definition in Z3:
                    # It is defined as t1 - (t1 /s t2) * t2, where /s represents signed division.
                    # The most significant bit (sign) of the result is equal to the most significant bit of \c t1.
                    if ssigned*tsigned >= 0:
                        usigned = ssigned - tsigned*(ssigned // tsigned)
                    else:
                        usigned = ssigned + tsigned*((-ssigned) // tsigned)
                    return usigned%(2**bits)
                W.add('div')
            else:
                print('values: internal error %s, falling back to Z3' % t.op)
                return False
            V[t.cache_key] = bits,functools.reduce(reduction,(xi[1] for xi in x))
            return True
        if t.op == '__invert__':
            assert len(x) == 1
            bits = x[0][0]
            V[t.cache_key] = bits,(1<<bits)-1-x[0][1]
            return True
        if t.op == 'Not':
            assert len(x) == 1
            assert all(xi[0] == 1 for xi in x)
            V[t.cache_key] = 1,1-x[0][1]
            return True
        if t.op in ('And','Or'):
            assert all(xi[0] == 1 for xi in x)
            if t.op == 'And':
                reduction = (lambda s,t:s*t)
            elif t.op == 'Or':
                reduction = (lambda s,t:s+t-s*t)
            else:
                print('values: internal error %s, falling back to Z3' % t.op)
                return False
            V[t.cache_key] = 1,functools.reduce(reduction,(xi[1] for xi in x))
            return True
        if t.op == 'If':
            assert len(x) == 3
            assert x[0][0] == 1
            if x[0][1]:
                V[t.cache_key] = x[1]
            else:
                V[t.cache_key] = x[2]
            return True
        if t.op in ('__le__','ULE','__lt__','ULT','__ge__','UGE','__gt__','UGT','SLE','SLT','SGE','SGT'):
            assert len(x) == 2
            bits = x[0][0]
            assert bits == x[1][0]
            # xor with the sign bit maps signed order onto unsigned order
            flip = 2**(bits-1)
            x0,x1 = x[0][1],x[1][1]
            if t.op == '__le__': V[t.cache_key] = (1,x0<=x1)
            elif t.op == 'ULE': V[t.cache_key] = (1,x0<=x1)
            elif t.op == '__lt__': V[t.cache_key] = (1,x0<x1)
            elif t.op == 'ULT': V[t.cache_key] = (1,x0<x1)
            elif t.op == '__ge__': V[t.cache_key] = (1,x0>=x1)
            elif t.op == 'UGE': V[t.cache_key] = (1,x0>=x1)
            elif t.op == '__gt__': V[t.cache_key] = (1,x0>x1)
            elif t.op == 'UGT': V[t.cache_key] = (1,x0>x1)
            elif t.op == 'SLE': V[t.cache_key] = (1,(x0^flip)<=(x1^flip))
            elif t.op == 'SLT': V[t.cache_key] = (1,(x0^flip)<(x1^flip))
            elif t.op == 'SGE': V[t.cache_key] = (1,(x0^flip)>=(x1^flip))
            elif t.op == 'SGT': V[t.cache_key] = (1,(x0^flip)>(x1^flip))
            else:
                print('values: internal error %s, falling back to Z3' % t.op)
                return False
            return True
        # XXX: add support for more
        print('values: unsupported operation %s, falling back to Z3' % t.op)
        return False
    # XXX: also add more validation for all of the above

    for t in terms:
        if not evaluate(t):
            return None,W
    return V,W

def unroll_print(outputs,unrolled,f):
    """Write the unrolled terms to f as a sequence of 'vN = ...' assignments.

    Shared subterms are printed once (keyed by cache_key); the final lines
    bind each 'out_<symbol>_<index>' name to the term computing it.
    """
    walked = {}

    def walk(t):
        if t.cache_key in walked:
            return walked[t.cache_key]
        if t.op == 'BoolV':
            walknext = len(walked)+1
            f.write('v%d = bool(%d)\n' % (walknext,t.args[0]))
        elif t.op == 'BVV':
            walknext = len(walked)+1
            f.write('v%d = constant(%d,%d)\n' % (walknext,t.size(),t.args[0]))
        elif t.op == 'BVS':
            walknext = len(walked)+1
            f.write('v%d = %s\n' % (walknext,t.args[0]))
        elif t.op == 'Extract':
            assert len(t.args) == 3
            input = 'v%d' % walk(t.args[2])
            walknext = len(walked)+1
            f.write('v%d = Extract(%s,%d,%d)\n' % (walknext,input,t.args[0],t.args[1]))
        elif t.op in ['SignExt','ZeroExt']:
            assert len(t.args) == 2
            input = 'v%d' % walk(t.args[1])
            walknext = len(walked)+1
            f.write('v%d = %s(%s,%d)\n' % (walknext,t.op,input,t.args[0]))
        else:
            inputs = ['v%d' % walk(a) for a in t.args]
            # number this term only after its arguments have been numbered
            walknext = len(walked)+1
            f.write('v%d = %s(%s)\n' % (walknext,t.op,','.join(inputs)))
        walked[t.cache_key] = walknext
        return walknext

    for x in unrolled:
        walk(x)

    unrolledpos = 0
    for csymbol,symboltype,entries in outputs:
        for i in range(entries):
            varname = 'out_%s_%d'%(csymbol,i)
            f.write('%s = v%s\n' % (varname,walk(unrolled[unrolledpos])))
            unrolledpos += 1

def unroll_inputvars(inputs):
    """Return [(name,claripy variable)] for every input entry, named 'in_<symbol>_<index>'."""
    result = []
    for csymbol,symboltype,entries in inputs:
        for i in range(entries):
            varname = 'in_%s_%d'%(csymbol,i)
            variable = claripy.BVS(varname,typebits[symboltype],explicit_name=True)
            result += [(varname,variable)]
    return result

# XXX: probably better to merge into unroll()
def unroll_worker(binary,inputs,outputs,rust):
    """Symbolically execute binary with angr and collect its outputs as terms.

    Symbolic variables are stored into the static_<symbol> arrays before
    execution, and the static_<symbol> output arrays are read from every
    deadended state.  Outputs of multiple exits are merged with If() on
    each exit's path constraints.

    Returns (numexits,ispartition,results) on success, or
    (-1,False,explanation) if angr reported errored states or the number
    of states exceeded maxsplit.
    """
    results = []

    if rust:
        proj = angr.Project(binary,auto_load_libs=False)
    else:
        avoidsimprocedures = (
            'memcmp', # want to test the real libc memcmp
        )
        proj = angr.Project(binary,exclude_sim_procedures_list=avoidsimprocedures,auto_load_libs=True)

    state = proj.factory.full_init_state()
    state.options |= {angr.options.LAZY_SOLVES}
    state.options |= {angr.options.SYMBOL_FILL_UNCONSTRAINED_MEMORY}
    state.options |= {angr.options.SYMBOL_FILL_UNCONSTRAINED_REGISTERS}
    state.options -= {angr.options.SIMPLIFY_EXPRS}
    state.options -= {angr.options.SIMPLIFY_REGISTER_WRITES}
    state.options -= {angr.options.SIMPLIFY_MEMORY_WRITES}
    state.options -= {angr.options.SIMPLIFY_REGISTER_READS}
    state.options -= {angr.options.SIMPLIFY_MEMORY_READS}

    for csymbol,symboltype,entries in inputs:
        xaddr = proj.loader.find_symbol('static_%s'%csymbol).rebased_addr
        for i in range(entries):
            varname = 'in_%s_%d'%(csymbol,i)
            variable = claripy.BVS(varname,typebits[symboltype],explicit_name=True)
            if typebits[symboltype] == 8:
                state.mem[xaddr+i].char = variable
            elif typebits[symboltype] == 16:
                state.mem[xaddr+2*i].short = variable
            elif typebits[symboltype] == 32:
                state.mem[xaddr+4*i].int = variable
            elif typebits[symboltype] == 64:
                state.mem[xaddr+8*i].long = variable

    simgr = proj.factory.simgr(state)
    while True:
        if len(simgr.errored) > 0:
            return -1,False,simgr.errored
        if len(simgr.deadended)+len(simgr.active) > maxsplit:
            return -1,False,'saferewrite limiting split to %d'%maxsplit
        if len(simgr.active) == 0:
            break
        simgr.step()

    exits = simgr.deadended
    assert len(exits) > 0

    # cannot be safe if there are multiple exits
    # for equivalence tests we'll merge exits below

    mergedconstraints = []
    for epos,e in enumerate(exits):
        mergedconstraint = e.solver.true
        for c in e.solver.constraints:
            mergedconstraint = e.solver.And(mergedconstraint,c)
        mergedconstraints += [mergedconstraint]

        resultpos = 0
        for csymbol,symboltype,entries in outputs:
            xaddr = proj.loader.find_symbol('static_%s'%csymbol).rebased_addr
            for i in range(entries):
                if typebits[symboltype] == 8:
                    xi = e.mem[xaddr+i].char.resolved
                elif typebits[symboltype] == 16:
                    xi = e.mem[xaddr+2*i].short.resolved
                elif typebits[symboltype] == 32:
                    xi = e.mem[xaddr+4*i].int.resolved
                elif typebits[symboltype] == 64:
                    xi = e.mem[xaddr+8*i].long.resolved
                if epos == 0:
                    assert len(results) == resultpos
                    results += [xi]
                else:
                    results[resultpos] = e.solver.If(mergedconstraint,xi,results[resultpos])
                resultpos += 1
        assert resultpos == len(results)

    assert len(mergedconstraints) == len(exits)

    ispartition = True
    # are mergedconstraints a partition of all universes?
    # i.e.: in each universe, exactly one of the constraints is satisfied?

    # at least one: "no constraint holds" must be unsatisfiable
    s = claripy.Solver(timeout=z3timeout)
    for c in mergedconstraints:
        s.add(claripy.Not(c))
    if s.satisfiable():
        ispartition = False

    # at most one: every pair must be mutually exclusive
    for i in range(len(exits)):
        for j in range(i):
            s = claripy.Solver(timeout=z3timeout)
            s.add(mergedconstraints[i])
            s.add(mergedconstraints[j])
            if s.satisfiable():
                ispartition = False

    return len(exits),ispartition,results

def unroll(o,p,i,compiler):
    """Unroll one build with angr and validate it against the CPU test results.

    Writes <builddir>/analysis/unrolled, notes warnings for split exits,
    foreign variables and unused inputs, and checks that the unrolled
    terms reproduce opic_y on every input in op_x (via values(), falling
    back to Z3).

    Returns (o,p,i,compiler,unrolled), with unrolled == False on failure.
    """
    compilerword = compiler.replace(' ','_').replace('=','_')
    implementationdir = '%s/%s/%s' % (o,p,i)
    builddir = 'build/%s/%s/%s' % (p,i,compilerword)
    inputs,outputs,funargs,funargtypes,crustargs,crustargtypes,rustargs,funname,funret,funrettype = op_api[o,p]

    rust = 'Cargo.toml' in os.listdir(implementationdir)

    unrolltime = -cputime()

    numexits,ispartition,unrolled = unroll_worker('%s/analysis-angr'%builddir,inputs,outputs,rust)

    if numexits < 1:
        note(builddir,'warning-unrollerror',unrolled)
        return o,p,i,compiler,False

    if not ispartition:
        note(builddir,'warning-unrollnotpartition')
        return o,p,i,compiler,False

    if numexits > 1:
        note(builddir,'unsafe-unrollsplit-%d'%numexits)

    with open('%s/analysis/unrolled' % builddir,'w') as f:
        unroll_print(outputs,unrolled,f)

    okvars = set(vname for vname,v in unroll_inputvars(inputs))
    usedvars = set(v for x in unrolled for v in x.variables)
    if not usedvars.issubset(okvars):
        note(builddir,'warning-unrollmem')
    if not okvars.issubset(usedvars):
        note(builddir,'warning-unusedinputs')

    Wdone = set()
    for x,y in zip(op_x[o,p],opic_y[o,p,i,compiler]):
        # cpu gave us outputs y given inputs x
        # does this match unrolled?

        replacements = {}
        xpos = 0
        for csymbol,symboltype,entries in inputs:
            for e in range(entries):
                varname = 'in_%s_%d'%(csymbol,e)
                variable = claripy.BVS(varname,typebits[symboltype],explicit_name=True)
                replacements[variable.cache_key] = claripy.BVV(x[xpos],typebits[symboltype])
                xpos += 1
        assert xpos == len(x)

        mismatch = True
        notestr = 'internal error\n'

        V = None
        try:
            V,W = values(unrolled,replacements)
            for warn in W:
                if warn in Wdone: continue
                Wdone.add(warn)
                note(builddir,'warning-%s'%warn)
        except AssertionError:
            note(builddir,'warning-valuesfailed',traceback.format_exc())
            # proceed with z3 fallback below

        if V is not None:
            mismatch = any(yi != V[unrolledi.cache_key][1] for (yi,unrolledi) in zip(y,unrolled))
            if mismatch:
                notestr = '# mismatch between CPU execution of binary and saferewrite execution of unrolled:\n'
                for (vname,v),xi in zip(unroll_inputvars(inputs),x):
                    notestr += '%s = %d\n' % (vname,xi)
                pos = 0
                for csymbol,symboltype,entries in outputs:
                    for e in range(entries):
                        varname = 'out_%s_%d'%(csymbol,e)
                        notestr += 'executed_%s = %s\n' % (varname,y[pos])
                        notestr += 'unrolled_%s = %s\n' % (varname,V[unrolled[pos].cache_key][1])
                        pos += 1
        else:
            # fall back on Z3 for figuring this out
            s = claripy.Solver(timeout=z3timeout)
            mismatch = claripy.false
            for yi,unrolledi in zip(y,unrolled):
                mismatch = claripy.Or(mismatch,unrolledi.replace_dict(replacements) != yi)
            s.add(mismatch)
            mismatch = s.satisfiable()
            if mismatch:
                notestr = '# hint: this type of mismatch typically reflects reading undefined memory.\n'
                notestr += '# mismatch between CPU execution of binary and Z3 execution of unrolled:\n'
                # the constraint was built with the inputs replaced by x, so
                # report x itself and evaluate the outputs under the same replacement
                for (vname,v),xi in zip(unroll_inputvars(inputs),x):
                    notestr += '%s = %d\n' % (vname,xi)
                pos = 0
                for csymbol,symboltype,entries in outputs:
                    for e in range(entries):
                        varname = 'out_%s_%d'%(csymbol,e)
                        notestr += 'executed_%s = %s\n' % (varname,y[pos])
                        notestr += 'unrolled_%s = %s\n' % (varname,s.eval(unrolled[pos].replace_dict(replacements),1)[0])
                        pos += 1

        if mismatch:
            note(builddir,'warning-unrollmismatch',notestr)
            return o,p,i,compiler,False

    unrolltime += cputime()
    notetime(builddir,'unroll',unrolltime)

    return o,p,i,compiler,unrolled

def wanttounroll():
    """Yield (o,p,i,compiler) for every build that has random-test results."""
    for o,p in primitives:
        for i,compiler in op_compiled[o,p]:
            if (o,p,i,compiler) in opic_y:
                yield o,p,i,compiler

opic_unrolled = {}
with multiprocessing.Pool(os_cores) as pool:
    for o,p,i,compiler,unrolled in pool.starmap(unroll,wanttounroll()):
        if unrolled is False: continue
        opic_unrolled[o,p,i,compiler] = unrolled

print('----- compareunrolled (can take tons of time)')

def compareunrolled(o,p,i,compiler,source,sourcecompiler):
    """Compare build (i,compiler) against build (source,sourcecompiler).

    First compares the random-test outputs; any difference is noted as
    unsafe-randomtest-... and (unless satvalidation1) ends the comparison.
    Then asks Z3 whether any unrolled output can differ, noting either
    unsafe-differentfrom-... with a counterexample or equals-....
    """
    compilerword = compiler.replace(' ','_').replace('=','_')
    sourcecompilerword = sourcecompiler.replace(' ','_').replace('=','_')
    implementationdir = '%s/%s/%s' % (o,p,i)
    builddir = 'build/%s/%s/%s' % (p,i,compilerword)
    inputs,outputs,funargs,funargtypes,crustargs,crustargtypes,rustargs,funname,funret,funrettype = op_api[o,p]

    randomtestspass = True
    for pos,(x,y,z) in enumerate(zip(op_x[o,p],opic_y[o,p,i,compiler],opic_y[o,p,source,sourcecompiler])):
        if y != z:
            xstr = input_example_str(inputs,x)
            note(builddir,'unsafe-randomtest-%d-differentfrom-%s-%s' % (pos,source,sourcecompilerword),xstr)
            randomtestspass = False
            # continue through random tests to get statistics
    if not randomtestspass:
        if not satvalidation1:
            return

    equivtime = -cputime()

    u1 = opic_unrolled[o,p,source,sourcecompiler]
    u2 = opic_unrolled[o,p,i,compiler]
    assert len(u1) == len(u2)

    # XXX: allow other equivalence-testing techniques
    s = claripy.Solver(timeout=z3timeout)
    different = claripy.false
    for u1j,u2j in zip(u1,u2):
        different = claripy.Or(different,u1j != u2j)
    s.add(different)
    try:
        mismatch = s.satisfiable()
    except claripy.errors.ClaripyZ3Error:
        # avoid crashing on the sort of bug fixed in https://github.com/angr/angr/pull/2887
        note(builddir,'warning-z3failed',traceback.format_exc())
        return

    if mismatch:
        # angr documentation says:
        # "If you don't add any constraints between two queries, the results will be consistent with each other."
        example = ''
        for vname,v in unroll_inputvars(inputs):
            example += '%s = %s\n' % (vname,s.eval(v,1)[0])
        unrolledpos = 0
        for csymbol,symboltype,entries in outputs:
            for e in range(entries):
                varname = 'out_%s_%d'%(csymbol,e)
                example += 'source_%s = %s\n' % (varname,s.eval(u1[unrolledpos],1)[0])
                example += 'target_%s = %s\n' % (varname,s.eval(u2[unrolledpos],1)[0])
                unrolledpos += 1
        note(builddir,'unsafe-differentfrom-%s-%s' % (source,sourcecompilerword),example)
    else:
        note(builddir,'equals-%s-%s' % (source,sourcecompilerword))

    equivtime += cputime()
    notetime(builddir,'equiv',equivtime)

def wanttocompareunrolled():
    """Yield (o,p,i,compiler,source,sourcecompiler) for every comparison to run.

    Every build is compared against 'ref' built with the same compiler;
    'ref' itself is compared against 'ref' built with compilerlist[0].
    Pairs lacking unrolled terms, and self-comparisons, are skipped.
    """
    for o,p in primitives:
        for i,compiler in op_compiled[o,p]:
            source = 'ref' # XXX: allow each implementation to choose source
            if i == 'ref':
                sourcecompiler = compilerlist[0] # XXX: maybe also allow choice
            else:
                sourcecompiler = compiler
            if (o,p,i,compiler) not in opic_unrolled: continue
            if (o,p,source,sourcecompiler) not in opic_unrolled: continue
            # XXX: could also do self-tests
            if (o,p,source,sourcecompiler) == (o,p,i,compiler): continue
            yield o,p,i,compiler,source,sourcecompiler

with multiprocessing.Pool(os_cores) as pool:
    list(pool.starmap(compareunrolled,wanttocompareunrolled()))