-rwxr-xr-x 54459 saferewrite-20250228/analyze raw
#!/usr/bin/env python3 import sys import os import logging import pickle import io import shutil import subprocess import multiprocessing import random import traceback import functools import setproctitle import resource import angr import claripy import struct signedtypes = set(('int8','int16','int32','int64')) typebits = { 'int8': 8, 'int16': 16, 'int32': 32, 'int64': 64, 'uint8': 8, 'uint16': 16, 'uint32': 32, 'uint64': 64, } rusttype = { 'int8': 'i8', 'int16': 'i16', 'int32': 'i32', 'int64': 'i64', 'uint8': 'u8', 'uint16': 'u16', 'uint32': 'u32', 'uint64': 'u64', } typestruct = { 'int8': 'b', 'int16': 'h', 'int32': 'i', 'int64': 'q', 'uint8': 'B', 'uint16': 'H', 'uint32': 'I', 'uint64': 'Q', } # ===== performance-related configuration... numrandomtests = 16 # test 16 random inputs to each function # (first two are all-0, all-1) # of course more time on testing and fuzzing would be useful maxsplit = 300 # max number of universes within an angr run # XXX: allow per-primitive/per-implementation configuration satvalidation1 = True # True means: even if random tests fail, invoke general sat/unsat mechanism # this allows extra validation of that mechanism # the following is somewhat ad-hoc monkey-patching # underlying problem: flatten simplifiers can easily blow up try: claripy.simplifications.simpleton._simplifiers = { 'Reverse': claripy.simplifications.simpleton.bv_reverse_simplifier, 'Extract': claripy.simplifications.simpleton.extract_simplifier, 'Concat': claripy.simplifications.simpleton.concat_simplifier, 'ZeroExt': claripy.simplifications.simpleton.zeroext_simplifier, 'SignExt': claripy.simplifications.simpleton.signext_simplifier, } except: try: claripy.simpleton._simplifiers = { 'Reverse': claripy.simpleton.bv_reverse_simplifier, 'Extract': claripy.simpleton.extract_simplifier, 'Concat': claripy.simpleton.concat_simplifier, 'ZeroExt': claripy.simpleton.zeroext_simplifier, 'SignExt': claripy.simpleton.signext_simplifier, } except: claripy.simplifications._all_simplifiers = { 'Reverse': claripy.simplifications.bv_reverse_simplifier, 'Extract': claripy.simplifications.extract_simplifier, 'Concat': claripy.simplifications.concat_simplifier, 'ZeroExt': claripy.simplifications.zeroext_simplifier, 'SignExt': claripy.simplifications.signext_simplifier, } sys.setrecursionlimit(1000000) # ===== real work begins here z3timeout = 2**32-1 # XXX: by default claripy satisfiable() has 5-minute timeout # and indicates timeout by returning False # current Z3 documentation says: # timeout (in milliseconds) (UINT_MAX and 0 mean no timeout) try: os_threads = len(os.sched_getaffinity(0)) except AttributeError: os_threads = multiprocessing.cpu_count() os_threads = os.getenv('THREADS',default=os_threads) os_threads = int(os_threads) if os_threads < 1: os_threads = 1 def cputime(): return resource.getrusage(resource.RUSAGE_SELF).ru_utime + resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime resource.setrlimit(resource.RLIMIT_NOFILE,2*resource.getrlimit(resource.RLIMIT_NOFILE)[1:]) def notetime(builddir,what,time): print('%s seconds %s %.6f' % (builddir,what,time)) sys.stdout.flush() with open('%s/analysis/seconds' % builddir,'a') as f: f.write('%s %.6f\n' % (what,time)) def note(builddir,conclusion,contents=None): print('%s %s' % (builddir,conclusion)) sys.stdout.flush() with open('%s/analysis/%s' % (builddir,conclusion),'w') as f: if contents is not None: f.write(str(contents)) startdir = os.getcwd() assert all(x in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_./' for x in startdir) shutil.rmtree('build',ignore_errors=True) os.makedirs('build') print('===== collecting compilers') with open('compilers') as f: compilers = f.read().splitlines() compiler_pos_cached = {} for pos,compiler in enumerate(compilers): compiler_pos_cached[compiler] = pos def compiler_pos(compiler): return compiler_pos_cached[compiler] def compiler_list(): for compiler in compilers: compiler = compiler.strip() if ':' not in compiler: continue yield compiler def compiler_command(compiler): return ':'.join(compiler.split(':')[1:]).strip() def compiler_hastag(compiler,tag): return tag in compiler.split(':')[0].split() def compiler_hasarch(compiler,arch): if arch is None: return True return any(all(compiler_hastag(compiler,x) for x in a.split()) for a in arch) def compiler_word(compiler): return compiler_command(compiler).replace(' ','_').replace('=','_') compiler_endianness = {} def checkcompiler(compiler): word = compiler_word(compiler) builddir = f'build/compiler/{word}' os.makedirs(f'{builddir}/analysis',exist_ok=True) language = 'cpp' if compiler_hastag(compiler,'cpp') else 'c' with open(f'{builddir}/endianness.{language}','w') as f: f.write('''#include <stdlib.h> #include <inttypes.h> int main() { uint16_t a = 0x0201; uint32_t b = 0x04030201; uint64_t c = 0x0807060504030201ULL; char *A = (char *) &a; char *B = (char *) &b; char *C = (char *) &c; if (A[0] == 1 && A[1] == 2 && B[0] == 1 && B[1] == 2 && B[2] == 3 && B[3] == 4 && C[0] == 1 && C[1] == 2 && C[2] == 3 && C[3] == 4 && C[4] == 5 && C[5] == 6 && C[6] == 7 && C[7] == 8) return 12; if (A[0] == 2 && A[1] == 1 && B[0] == 4 && B[1] == 3 && B[2] == 2 && B[3] == 1 && C[0] == 8 && C[1] == 7 && C[2] == 6 && C[3] == 5 && C[4] == 4 && C[5] == 3 && C[6] == 2 && C[7] == 1) return 21; return 99; } ''') command = f'{compiler_command(compiler)} -static -o endianness endianness.{language}' try: proc = subprocess.Popen(command.split(),cwd=builddir,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True) out,err = proc.communicate() except (OSError,FileNotFoundError): note(builddir,f'warning-compilefailed-endianness.{language}',traceback.format_exc()) return compiler,None,False assert not err if out != '': note(builddir,'warning-compileoutput-endianness.{language}',out) if proc.returncode: note(builddir,'warning-compilefailed-endianness.{language}',f'exit code {proc.returncode}\n') return compiler,None,False command = './endianness' try: proc = subprocess.Popen(command.split(),cwd=builddir,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True) out,err = proc.communicate() except (OSError,FileNotFoundError): note(builddir,f'warning-executefailed-endianness',traceback.format_exc()) return compiler,None,False assert not err if out != '': note(builddir,'warning-executeoutput-endianness',out) if proc.returncode == 12: note(builddir,f'endianness-little') return compiler,'little',True if proc.returncode == 21: note(builddir,f'endianness-big') return compiler,'big',True note(builddir,'warning-executefailed-endianness',f'exit code {proc.returncode}\n') return compiler,None,False with multiprocessing.Pool(os_threads) as pool: for compiler,endianness,ok in pool.map(checkcompiler,compiler_list(),chunksize=1): if not ok: continue compiler_endianness[compiler] = endianness def compiler_list2(): for compiler in compilers: if compiler in compiler_endianness: yield compiler print('===== collecting primitives') primitives = [] for o in 'src',: o = o.strip() if o == '': continue if not os.path.isdir(o): continue if os.stat('%s' % o).st_mode & 0o1000 == 0o1000: print('%s sticky, skipping' % o) sys.stdout.flush() continue for p in sorted(os.listdir(o)): if p == 'compiler': continue if not os.path.isdir('%s/%s' % (o,p)): continue if os.stat('%s/%s' % (o,p)).st_mode & 0o1000 == 0o1000: print('%s/%s sticky, skipping' % (o,p)) sys.stdout.flush() continue if not os.path.exists('%s/%s/api' % (o,p)): print('%s/%s/api nonexistent, skipping' % (o,p)) sys.stdout.flush() continue primitives += [(o,p)] op_api = {} for o,p in primitives: inputs = [] outputs = [] funargs = [] funargtypes = [] rustargs = [] rustargtypes = [] crustargs = [] crustargtypes = [] funname = None funret = None funrettype = 'void' rustfunrettype = None with open('%s/%s/api' % (o,p)) as f: for line in f: line = line.split() if len(line) == 0: continue if line[0] == 'call': funname = line[1] assert all(x in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_' for x in funname) if line[0] == 'return': symboltype = line[1] csymbol = line[2] assert all(c in 'abcdefghijklmnopqrstuvwxyz' for c in csymbol) entries = 1 outputs += [(csymbol,symboltype,entries)] funret = 'var_%s'%csymbol funrettype = f'{symboltype}_t' rustfunrettype = rusttype[symboltype] if line[0] in ('in','out','inout'): symboltype = line[1] csymbol = line[2] assert all(c in 'abcdefghijklmnopqrstuvwxyz' for c in csymbol) if len(line) == 3: pointer = False entries = 1 else: pointer = True entries = int(line[3]) if line[0] in ('in','inout'): inputs += [(csymbol,symboltype,entries)] if line[0] in ('out','inout'): outputs += [(csymbol,symboltype,entries)] if pointer: funargs += ['var_%s'%csymbol] if line[0] == 'in': funargtypes += [f'const {symboltype}_t *'] else: funargtypes += [f'{symboltype}_t *'] crustargs += ['c_%s'%csymbol] rustargs += ['rust_%s'%csymbol] if line[0] == 'in': crustargtypes += ['*const %s' % rusttype[symboltype]] else: crustargtypes += ['*mut %s' % rusttype[symboltype]] else: funargs += ['var_%s[0]'%csymbol] funargtypes += [f'{symboltype}_t'] crustargs += ['c_%s'%csymbol] rustargs += ['rust_%s'%csymbol] crustargtypes += ['const %s' % rusttype[symboltype]] # XXX: support constant inputs op_api[o,p] = inputs,outputs,funargs,funargtypes,crustargs,crustargtypes,rustargs,funname,funret,funrettype,rustfunrettype def input_example_str(inputs,x): xstr = '' xpos = 0 for csymbol,symboltype,entries in inputs: for e in range(entries): varname = 'in_%s_%d'%(csymbol,e) xstr += '%s = %d\n' % (varname,x[xpos]) xpos += 1 assert xpos == len(x) return xstr def output_example_str(outputs,y): ystr = '' ypos = 0 for csymbol,symboltype,entries in outputs: for e in range(entries): varname = 'out_%s_%d'%(csymbol,e) ystr += '%s = %d\n' % (varname,y[ypos]) ypos += 1 assert ypos == len(y) return ystr reservedfilenames = ( 'library.so.1', 'unroll-pickle', 'unroll-log', 'analysis', 'analysis-execute', 'analysis-execute.c', 'analysis-execute.cpp', 'analysis-valgrind', 'analysis-valgrind.c', 'analysis-valgrind.cpp', 'analysis-angr', 'analysis-angr.c', 'analysis-angr.cpp', 'subroutines.c', 'subroutines.cpp', ) opi_language = {} opi_architectures = {} opimplementations = {} for o,p in primitives: opimplementations[o,p] = [] for i in sorted(os.listdir('%s/%s' % (o,p))): implementationdir = '%s/%s/%s' % (o,p,i) if not os.path.isdir(implementationdir): continue if os.stat(implementationdir).st_mode & 0o1000 == 0o1000: print('%s/%s/%s sticky, skipping' % (o,p,i)) continue files = sorted(os.listdir(implementationdir)) for f in files: ok = True if f in reservedfilenames: print('%s/%s/%s/%s reserved filename' % (o,p,i,f)) ok = False if any(fi not in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.' for fi in f): print('%s/%s/%s/%s prohibited character' % (o,p,i,f)) ok = False if not ok: continue opimplementations[o,p] += [i] language = 'c' for f in files: if f.endswith('cc') or f.endswith('.cpp'): language = 'cpp' opi_language[o,p,i] = language opi_architectures[o,p,i] = None if 'architectures' in files: with open(f'{implementationdir}/architectures') as f: opi_architectures[o,p,i] = f.read().splitlines() def compile(o,p,i,compiler): language = opi_language[o,p,i] word = compiler_word(compiler) implementationdir = '%s/%s/%s' % (o,p,i) inputs,outputs,funargs,funargtypes,crustargs,crustargtypes,rustargs,funname,funret,funrettype,rustfunrettype = op_api[o,p] elfulator = compiler_hastag(compiler,'elfulator') files = sorted(os.listdir(implementationdir)) files += [f'subroutines.{language}'] rust = 'Cargo.toml' in files rustlocked = 'Cargo.lock' in files cfiles = [x for x in files if x.endswith('.c') or x.endswith('.cc') or x.endswith('.cpp')] sfiles = [x for x in files if x.endswith('.s') or x.endswith('.S')] files = cfiles + sfiles afiles = list(files) os.makedirs('build/%s/%s' % (p,i),exist_ok=True) builddir = 'build/%s/%s/%s' % (p,i,word) shutil.copytree(implementationdir,builddir) os.makedirs('%s/analysis' % builddir) for bits in 8,16,32,64: if not os.path.exists(f'{builddir}/crypto_int{bits}.h'): with open('%s/crypto_int%d.h' % (builddir,bits),'w') as f: f.write('#include <inttypes.h>\n') f.write('#define crypto_int%d int%d_t' % (bits,bits)) if not os.path.exists(f'{builddir}/crypto_uint{bits}.h'): with open('%s/crypto_uint%d.h' % (builddir,bits),'w') as f: f.write('#include <inttypes.h>\n') f.write('#define crypto_uint%d uint%d_t' % (bits,bits)) with open(f'{builddir}/subroutines.{language}','w') as f: f.write('''#include <stddef.h> int memcmp(const void *u,const void *v,size_t n) { unsigned char *x = (unsigned char *) u; unsigned char *y = (unsigned char *) v; for (;n > 0;--n,++x,++y) { int result = *x; result -= *y; if (result) return result; } return 0; } int bcmp(const void *u,const void *v,size_t n) { return memcmp(u,v,n); } ''') for analysis in 'execute','valgrind','angr': afiles += ['analysis-%s.%s' % (analysis,language)] with open('%s/analysis-%s.%s' % (builddir,analysis,language),'w') as f: if analysis != 'valgrind': f.write('#include <unistd.h>\n') f.write('#include <stdlib.h>\n') f.write('#include <string.h>\n') f.write('#include <inttypes.h>\n') f.write('\n') # function declaration f.write('extern ') if funrettype is not None: f.write('%s ' % funrettype) f.write('%s(%s);\n' % (funname,','.join(funargtypes))) f.write('\n') avoidalloc = (analysis == 'execute') or (analysis == 'angr' and elfulator) if avoidalloc: for csymbol,symboltype,entries in inputs: f.write(f'{symboltype}_t var_{csymbol}[{entries}];\n') for csymbol,symboltype,entries in outputs: if (csymbol,symboltype,entries) not in inputs: f.write(f'{symboltype}_t var_{csymbol}[{entries}];\n') f.write('\n') f.write('int main(int argc,char **argv)\n') f.write('{\n') if not avoidalloc: for csymbol,symboltype,entries in inputs: allocbytes = entries*typebits[symboltype]//8 f.write(f' {symboltype}_t *var_{csymbol} = ({symboltype}_t *) malloc({allocbytes});\n') for csymbol,symboltype,entries in outputs: if (csymbol,symboltype,entries) not in inputs: allocbytes = entries*typebits[symboltype]//8 f.write(f' {symboltype}_t *var_{csymbol} = ({symboltype}_t *) malloc({allocbytes});\n') f.write('\n') # XXX: resource limits if analysis != 'valgrind': for csymbol,symboltype,entries in inputs: allocbytes = typebits[symboltype]//8 f.write(' for (long long i = 0;i < %d;++i)\n' % entries) f.write(' if (read(0,(char *) &var_%s[i],%s) != %s)\n' % (csymbol,allocbytes,allocbytes)) f.write(' exit(111);\n') f.write('\n') f.write(' ') if funret is not None: f.write('%s[0] = ' % funret) f.write('%s(%s);\n' % (funname,','.join(funargs))) f.write('\n') if analysis != 'valgrind': for csymbol,symboltype,entries in outputs: allocbytes = typebits[symboltype]//8 f.write(' for (long long i = 0;i < %d;++i)\n' % entries) f.write(' if (write(1,(char *) &var_%s[i],%s) != %s)\n' % (csymbol,allocbytes,allocbytes)) f.write(' exit(111);\n') f.write('\n') f.write(' return 0;\n') f.write('}\n') # ===== rust if rust: os.makedirs('%s/src/bin'%builddir,exist_ok=True) with open('%s/Cargo.toml'%builddir,'a') as f: f.write("""\ [[bin]] name = "analysis-execute" [[bin]] name = "analysis-valgrind" [[bin]] name = "analysis-angr" [profile.release] panic = "abort" """) with open('%s/build.rs'%builddir,'w') as f: f.write("""\ use std::env; fn main() { let out_dir = env::var("OUT_DIR").unwrap(); println!("cargo:rustc-link-search=native={}",out_dir); cc::Build::new() .cargo_metadata(false) .file("analysis-execute.c") .compile("analysis-execute"); println!("cargo:rerun-if-changed=analysis-execute.c"); cc::Build::new() .cargo_metadata(false) .file("analysis-valgrind.c") .compile("analysis-valgrind"); println!("cargo:rerun-if-changed=analysis-valgrind.c"); cc::Build::new() .cargo_metadata(false) .file("analysis-angr.c") .compile("analysis-angr"); println!("cargo:rerun-if-changed=analysis-angr.c"); } """) for analysis in 'execute','valgrind','angr': with open('%s/src/bin/analysis-%s.rs'%(builddir,analysis),'w') as f: f.write('#![no_std]\n') f.write('#![no_main]\n') f.write('extern crate %s;\n'%funname) f.write('\n') if any(f'rust_{csymbol}' in rustargs for (csymbol,symboltype,entries) in outputs): f.write('use core::convert::TryInto;\n') f.write('use core::slice;\n') f.write('\n') f.write('#[link(name="analysis-%s")]\n'%analysis) f.write('extern {\n') f.write('#[allow(dead_code)]\n') f.write(' fn main() -> i32;\n') f.write('}\n') f.write('\n') f.write('#[no_mangle]\n') f.write('pub unsafe extern "C"\n') f.write('fn %s(\n'%funname) for arg,argtype in zip(crustargs,crustargtypes): f.write(' %s:%s,\n' % (arg,argtype)) if rustfunrettype is None: f.write(') -> i32 {\n') else: f.write(') -> %s {\n' % rustfunrettype) for csymbol,symboltype,entries in outputs: if f'rust_{csymbol}' not in rustargs: continue f.write(' let rust_%s = slice::from_raw_parts_mut(c_%s,%d);\n' % (csymbol,csymbol,entries)) f.write(' let rust_%s:&mut[%s;%d] = rust_%s.try_into().unwrap();\n' % (csymbol,rusttype[symboltype],entries,csymbol)) for csymbol,symboltype,entries in inputs: if (csymbol,symboltype,entries) not in outputs: f.write(' let rust_%s = slice::from_raw_parts(c_%s,%d);\n' % (csymbol,csymbol,entries)) if rustfunrettype is None: f.write(' %s::%s(%s);\n' % (funname,funname,','.join(rustargs))) f.write(' 0\n') else: f.write(' %s::%s(%s)\n' % (funname,funname,','.join(rustargs))) f.write('}\n') cargotime = -cputime() env = os.environ env['CC'] = compiler_command(compiler).split()[0] env['CFLAGS'] = ' '.join(compiler_command(compiler).split()[1:]) env['RUSTFLAGS'] = '-C target_cpu=native -C link-args=-no-pie' # XXX: let user control this? if rustlocked: command = 'cargo build -q --locked --release --bin analysis-execute --bin analysis-valgrind --bin analysis-angr' else: command = 'cargo build -q --release --bin analysis-execute --bin analysis-valgrind --bin analysis-angr' try: proc = subprocess.Popen(command.split(),cwd=builddir,env=env,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True) out,err = proc.communicate() except (OSError,FileNotFoundError): note(builddir,'warning-cargofailed',traceback.format_exc()) return o,p,i,compiler,False assert not err if out != '': note(builddir,'warning-cargooutput',out) if proc.returncode: note(builddir,'warning-cargofailed','exit code %s\n' % proc.returncode) return o,p,i,compiler,False for analysis in 'execute','valgrind','angr': shutil.copy('%s/target/release/analysis-%s'%(builddir,analysis),'%s/analysis-%s'%(builddir,analysis)) cargotime += cputime() notetime(builddir,'cargo',cargotime) return o,p,i,compiler,True # ===== compile compiletime = -cputime() objfiles = [] for f in afiles: command = '%s -Wall -Wextra -Wno-unused-function -Wno-unused-parameter -fPIC -DCRYPTO_NAMESPACE(x)=x -c %s' % (compiler_command(compiler),f) if f in ('analysis-valgrind.c','analysis-valgrind.cpp'): command += ' -Wno-uninitialized' try: proc = subprocess.Popen(command.split(),cwd=builddir,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True) out,err = proc.communicate() except (OSError,FileNotFoundError): note(builddir,'warning-compilefailed-%s'%f,traceback.format_exc()) return o,p,i,compiler,False assert not err if out != '': note(builddir,'warning-compileoutput-%s'%f,out) if proc.returncode: note(builddir,'warning-compilefailed-%s'%f,'exit code %s\n' % proc.returncode) return o,p,i,compiler,False if f in files: objfiles += ['.'.join(f.split('.')[:-1]+['o'])] compiletime += cputime() notetime(builddir,'compile',compiletime) # ===== link into executable linktime = -cputime() for analysis in 'execute','valgrind','angr': directlink = True if directlink: command = f'{compiler_command(compiler)} -no-pie -o analysis-{analysis} analysis-{analysis}.o' command = command.split() command += objfiles else: command = f'{compiler_command(compiler)} -shared -Wl,-soname,library.so.1 -o library.so.1' command = command.split() command += objfiles try: proc = subprocess.Popen(command,cwd=builddir,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True) out,err = proc.communicate() except (OSError,FileNotFoundError): note(builddir,'warning-linkfailed-%s'%analysis,traceback.format_exc()) return o,p,i,compiler,False if out != '': note(builddir,'warning-linkoutput-%s'%analysis,out) assert not err if proc.returncode: note(builddir,'warning-linkfailed-%s'%analysis,'exit code %s\n' % proc.returncode) return o,p,i,compiler,False shutil.copy('%s/library.so.1' % builddir,'%s/library.so' % builddir) command = f'{compiler_command(compiler)} -no-pie -o analysis-%s analysis-%s.o -Wl,-rpath=%s/%s -L. -lrary' % (analysis,analysis,startdir,builddir) command = command.split() try: proc = subprocess.Popen(command,cwd=builddir,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True) out,err = proc.communicate() except (OSError,FileNotFoundError): note(builddir,'warning-linkfailed-%s'%analysis,traceback.format_exc()) return o,p,i,compiler,False if out != '': note(builddir,'warning-linkoutput-%s'%analysis,out) assert not err if proc.returncode: note(builddir,'warning-linkfailed-%s'%analysis,'exit code %s\n' % proc.returncode) return o,p,i,compiler,False linktime += cputime() notetime(builddir,'link',linktime) return o,p,i,compiler,True def wanttocompile(): for o,p in primitives: for i in opimplementations[o,p]: for compiler in compiler_list2(): if compiler_hastag(compiler,opi_language[o,p,i]): if compiler_hasarch(compiler,opi_architectures[o,p,i]): yield o,p,i,compiler op_compiled = {} for o,p in primitives: op_compiled[o,p] = [] with multiprocessing.Pool(os_threads) as pool: for o,p,i,compiler,ok in pool.starmap(compile,wanttocompile(),chunksize=1): if not ok: continue op_compiled[o,p] += [(i,compiler)] print('===== execute') op_x = {} for o,p in primitives: inputs,outputs,funargs,funargtypes,crustargs,crustargtypes,rustargs,funname,funret,funrettype,rustfunrettype = op_api[o,p] op_x[o,p] = [] for execution in range(numrandomtests): x = [] for csymbol,symboltype,entries in inputs: for e in range(entries): if execution == 0: r = 0 elif execution == 1: r = -1 if symboltype in signedtypes else 2**typebits[symboltype]-1 else: r = random.randrange(2**typebits[symboltype]) if symboltype in signedtypes: r -= 2**(typebits[symboltype]-1) x += [r] op_x[o,p] += [x] def execute(o,p,i,compiler): word = compiler_word(compiler) implementationdir = '%s/%s/%s' % (o,p,i) builddir = 'build/%s/%s/%s' % (p,i,word) inputs,outputs,funargs,funargtypes,crustargs,crustargtypes,rustargs,funname,funret,funrettype,rustfunrettype = op_api[o,p] executetime = -cputime() results = [] command = ['./analysis-execute'] for x in op_x[o,p]: xstr = b'' xpos = 0 for csymbol,symboltype,entries in inputs: for e in range(entries): if symboltype in signedtypes: assert x[xpos] >= -2**(typebits[symboltype]-1) assert x[xpos] < 2**(typebits[symboltype]-1) else: assert x[xpos] >= 0 assert x[xpos] < 2**typebits[symboltype] structfmt = '<' if compiler_endianness[compiler] == 'little' else '>' structfmt += typestruct[symboltype] xstr += struct.pack(structfmt,x[xpos]) xpos += 1 assert xpos == len(x) try: proc = subprocess.Popen(command,cwd=builddir,stdin=subprocess.PIPE,stdout=subprocess.PIPE,stderr=subprocess.STDOUT) ystr,err = proc.communicate(input=xstr) except (OSError,FileNotFoundError): note(builddir,'warning-executeerror',xstr) return o,p,i,compiler,False if proc.returncode != 0: note(builddir,'warning-executefailed',f'exit code {proc.returncode}\n'+input_example_str(inputs,x)) return o,p,i,compiler,False try: y = [] ystrpos = 0 for csymbol,symboltype,entries in outputs: for e in range(entries): structfmt = '<' if compiler_endianness[compiler] == 'little' else '>' structfmt += typestruct[symboltype] structbytes = typebits[symboltype]//8 y += list(struct.unpack(structfmt,ystr[ystrpos:ystrpos+structbytes])) if symboltype in signedtypes: assert y[-1] >= -2**(typebits[symboltype]-1) assert y[-1] < 2**(typebits[symboltype]-1) else: assert y[-1] >= 0 assert y[-1] < 2**typebits[symboltype] ystrpos += structbytes assert ystrpos == len(ystr) except ValueError: note(builddir,'warning-executebadformat',input_example_str(inputs,x)) return o,p,i,compiler,False results += [y] executetime += cputime() notetime(builddir,'execute',executetime) return o,p,i,compiler,results def wanttoexecute(): for o,p in primitives: for i,compiler in op_compiled[o,p]: yield o,p,i,compiler opic_y = {} with multiprocessing.Pool(os_threads) as pool: for o,p,i,compiler,results in pool.starmap(execute,wanttoexecute(),chunksize=1): if results == False: continue opic_y[o,p,i,compiler] = results print('===== valgrind (can take some time)') def valgrind(o,p,i,compiler): word = compiler_word(compiler) implementationdir = '%s/%s/%s' % (o,p,i) builddir = 'build/%s/%s/%s' % (p,i,word) valgrindtime = -cputime() command = ['valgrind','-q','--error-exitcode=99','./analysis-valgrind'] valgrindstatus = None try: proc = subprocess.Popen(command,cwd=builddir,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True) out,err = proc.communicate() except (OSError,FileNotFoundError): valgrindstatus = 'warning-valgrinderror' if valgrindstatus is None: assert not err if proc.returncode == 99: valgrindstatus = 'unsafe-valgrindfailure' elif proc.returncode != 0: valgrindstatus = 'warning-valgrinderror' elif out.find('client request') >= 0: valgrindstatus = 'unsafe-valgrindfailure' else: valgrindstatus = 'passed-valgrind' if valgrindstatus is not None: note(builddir,valgrindstatus) valgrindtime += cputime() notetime(builddir,'valgrind',valgrindtime) def wanttovalgrind(): for o,p in primitives: for i,compiler in op_compiled[o,p]: if compiler_hastag(compiler,'valgrind'): yield o,p,i,compiler with multiprocessing.Pool(os_threads) as pool: list(pool.starmap(valgrind,wanttovalgrind(),chunksize=1)) print('===== unroll (can take tons of time)') # XXX: could do this in parallel with valgrind def clarikey(e): try: # e.g. angr 9.2.102 return e.cache_key except: # e.g. angr 9.2.144 return e.hash() def values(terms,replacements): # input: replacements mapping clarikey to integers # output: dictionary V mapping clarikey to pairs (b,i) where i is a b-bit value # output includes all terms # _or_ output is None if terms use variables outside replacements V = {} W = set() # warnings def evaluate(t): if clarikey(t) in V: return True if t.op == 'BoolV': V[clarikey(t)] = 1,t.args[0] return True if t.op == 'BVV': V[clarikey(t)] = t.size(),t.args[0] return True if t.op == 'BVS': if clarikey(t) not in replacements: return False V[clarikey(t)] = t.size(),replacements[clarikey(t)].args[0] return True if t.op == 'Extract': assert len(t.args) == 3 top = t.args[0] bot = t.args[1] if not evaluate(t.args[2]): return False x0 = V[clarikey(t.args[2])] assert x0[0] > top assert top >= bot assert bot >= 0 V[clarikey(t)] = top+1-bot,((x0[1] & ((2<<top)-1)) >> bot) return True if t.op in ('SignExt','ZeroExt'): assert len(t.args) == 2 if not evaluate(t.args[1]): return False x0bits,x0 = V[clarikey(t.args[1])] extend = t.args[0] assert extend >= 0 if t.op == 'SignExt': if x0 >= (1<<(x0bits-1)): x0 -= 1<<x0bits x0 += 1<<(x0bits+extend) V[clarikey(t)] = x0bits+extend,x0 return True for a in t.args: if not evaluate(a): return False x = [V[clarikey(a)] for a in t.args] if t.op == 'Concat': y = 0 ybits = 0 for xbitsi,xi in x: y <<= xbitsi y += xi ybits += xbitsi V[clarikey(t)] = ybits,y return True if t.op == 'Reverse': assert len(x) == 1 xbits0,x0 = x[0] ybits = xbits0 assert ybits%8 == 0 y = 0 for i in range(ybits//8): y = (y<<8)+((x0>>(8*i))&255) V[clarikey(t)] = ybits,y return True if t.op in ('__eq__','__ne__'): assert len(x) == 2 assert x[0][0] == x[1][0] if t.op == '__eq__': V[clarikey(t)] = 1,(x[0][1]==x[1][1]) elif t.op == '__ne__': V[clarikey(t)] = 1,(x[0][1]!=x[1][1]) else: print('values: internal error %s, falling back to Z3' % t.op) return False return True if t.op in ('__add__','__mul__','SDiv','__floordiv__','SMod','__mod__','__sub__','__lshift__','LShR','__rshift__','__and__','__or__','__xor__'): bits = x[0][0] assert all(xi[0] == bits for xi in x) if t.op == '__and__': reduction = (lambda s,t:s&t) elif t.op == '__or__': reduction = (lambda s,t:s|t) elif t.op == '__xor__': reduction = (lambda s,t:s^t) elif t.op == '__add__': reduction = (lambda s,t:(s+t)%(2**bits)) elif t.op == '__sub__': reduction = (lambda s,t:(s-t)%(2**bits)) elif t.op == '__lshift__': reduction = (lambda s,t:(s<<t)%(2**bits)) elif t.op == 'LShR': reduction = (lambda s,t:(s>>t)%(2**bits)) elif t.op == '__rshift__': def reduction(s,t): flip = 2**(bits-1) ssigned = (s ^ flip) - flip tsigned = (t ^ flip) - flip # XXX: what are semantics for tsigned outside [0:bits]? also check __lshift__, LShR usigned = ssigned >> tsigned return usigned%(2**bits) elif t.op == '__mul__': reduction = (lambda s,t:(s*t)%(2**bits)) W.add('mul') elif t.op == '__floordiv__': def reduction(s,t): if t == 0: return 0 return (s//t)%(2**bits) W.add('div') elif t.op == '__mod__': def reduction(s,t): if t == 0: return s return (s%t)%(2**bits) W.add('div') elif t.op == 'SDiv': def reduction(s,t): if t == 0: return 0 flip = 2**(bits-1) ssigned = (s ^ flip) - flip tsigned = (t ^ flip) - flip # sdiv definition in Z3: # - The \c floor of [t1/t2] if \c t2 is different from zero, and [t1*t2 >= 0]. # - The \c ceiling of [t1/t2] if \c t2 is different from zero, and [t1*t2 < 0]. if ssigned*tsigned >= 0: usigned = ssigned // tsigned else: usigned = -((-ssigned) // tsigned) return usigned%(2**bits) W.add('div') elif t.op == 'SMod': def reduction(s,t): if t == 0: return s flip = 2**(bits-1) ssigned = (s ^ flip) - flip tsigned = (t ^ flip) - flip # srem definition in Z3: # It is defined as t1 - (t1 /s t2) * t2, where /s represents signed division. # The most significant bit (sign) of the result is equal to the most significant bit of \c t1. if ssigned*tsigned >= 0: usigned = ssigned - tsigned*(ssigned // tsigned) else: usigned = ssigned + tsigned*((-ssigned) // tsigned) return usigned%(2**bits) W.add('div') else: print('values: internal error %s, falling back to Z3' % t.op) return False V[clarikey(t)] = bits,functools.reduce(reduction,(xi[1] for xi in x)) return True if t.op == '__invert__': assert len(x) == 1 bits = x[0][0] V[clarikey(t)] = bits,(1<<bits)-1-x[0][1] return True if t.op == 'Not': assert len(x) == 1 assert all(xi[0] == 1 for xi in x) V[clarikey(t)] = 1,1-x[0][1] return True if t.op in ('And','Or'): assert all(xi[0] == 1 for xi in x) if t.op == 'And': reduction = (lambda s,t:s*t) elif t.op == 'Or': reduction = (lambda s,t:s+t-s*t) else: print('values: internal error %s, falling back to Z3' % t.op) return False V[clarikey(t)] = 1,functools.reduce(reduction,(xi[1] for xi in x)) return True if t.op == 'If': assert len(x) == 3 assert x[0][0] == 1 if x[0][1]: V[clarikey(t)] = x[1] else: V[clarikey(t)] = x[2] return True if t.op in ('__le__','ULE','__lt__','ULT','__ge__','UGE','__gt__','UGT','SLE','SLT','SGE','SGT'): assert len(x) == 2 bits = x[0][0] assert bits == x[1][0] flip = 2**(bits-1) x0,x1 = x[0][1],x[1][1] if t.op == '__le__': V[clarikey(t)] = (1,x0<=x1) elif t.op == 'ULE': V[clarikey(t)] = (1,x0<=x1) elif t.op == '__lt__': V[clarikey(t)] = (1,x0<x1) elif t.op == 'ULT': V[clarikey(t)] = (1,x0<x1) elif t.op == '__ge__': V[clarikey(t)] = (1,x0>=x1) elif t.op == 'UGE': V[clarikey(t)] = (1,x0>=x1) elif t.op == '__gt__': V[clarikey(t)] = (1,x0>x1) elif t.op == 'UGT': V[clarikey(t)] = (1,x0>x1) elif t.op == 'SLE': V[clarikey(t)] = (1,(x0^flip)<=(x1^flip)) elif t.op == 'SLT': V[clarikey(t)] = (1,(x0^flip)<(x1^flip)) elif t.op == 'SGE': V[clarikey(t)] = (1,(x0^flip)>=(x1^flip)) elif t.op == 'SGT': V[clarikey(t)] = (1,(x0^flip)>(x1^flip)) else: print('values: internal error %s, falling back to Z3' % t.op) return False return True # XXX: add support for more print('values: unsupported operation %s, falling back to Z3' % t.op) return False # XXX: also add more validation for all of the above for t in terms: if not evaluate(t): return None,W return V,W def unroll_print(outputs,unrolled,f): walked = {} def walk(t): if clarikey(t) in walked: return walked[clarikey(t)] if t.op == 'BoolV': walknext = len(walked)+1 f.write('v%d = bool(%d)\n' % (walknext,t.args[0])) elif t.op == 'BVV': walknext = len(walked)+1 f.write('v%d = constant(%d,%d)\n' % (walknext,t.size(),t.args[0])) elif t.op == 'BVS': walknext = len(walked)+1 f.write('v%d = %s\n' % (walknext,t.args[0])) elif t.op == 'Extract': assert len(t.args) == 3 input = 'v%d' % walk(t.args[2]) walknext = len(walked)+1 f.write('v%d = Extract(%s,%d,%d)\n' % (walknext,input,t.args[0],t.args[1])) elif t.op in ['SignExt','ZeroExt']: assert len(t.args) == 2 input = 'v%d' % walk(t.args[1]) walknext = len(walked)+1 f.write('v%d = %s(%s,%d)\n' % (walknext,t.op,input,t.args[0])) else: inputs = ['v%d' % walk(a) for a in t.args] walknext = len(walked)+1 f.write('v%d = %s(%s)\n' % (walknext,t.op,','.join(inputs))) walked[clarikey(t)] = walknext return walknext for x in unrolled: walk(x) unrolledpos = 0 for csymbol,symboltype,entries in outputs: for i in range(entries): varname = 'out_%s_%d'%(csymbol,i) f.write('%s = v%s\n' % (varname,walk(unrolled[unrolledpos]))) unrolledpos += 1 def unroll_inputvars(inputs): result = [] for csymbol,symboltype,entries in inputs: for i in range(entries): varname = 'in_%s_%d'%(csymbol,i) variable = claripy.BVS(varname,typebits[symboltype],explicit_name=True) result += [(varname,variable)] return result # XXX: probably better to merge into unroll() def unroll_worker(binary,inputs,outputs,rust,compiler): elfulator = compiler_hastag(compiler,'elfulator') add_options = { angr.options.LAZY_SOLVES, angr.options.SYMBOLIC_WRITE_ADDRESSES, angr.options.CONSERVATIVE_READ_STRATEGY, angr.options.CONSERVATIVE_WRITE_STRATEGY, } remove_options = { angr.options.SIMPLIFY_CONSTRAINTS, # angr.options.SIMPLIFY_EXIT_GUARD, # angr.options.SIMPLIFY_EXIT_STATE, # angr.options.SIMPLIFY_EXIT_TARGET, angr.options.SIMPLIFY_EXPRS, angr.options.SIMPLIFY_MEMORY_READS, angr.options.SIMPLIFY_MEMORY_WRITES, angr.options.SIMPLIFY_REGISTER_READS, angr.options.SIMPLIFY_REGISTER_WRITES, angr.options.SIMPLIFY_RETS, } if elfulator: add_options |= {angr.options.ZERO_FILL_UNCONSTRAINED_MEMORY} add_options |= {angr.options.ZERO_FILL_UNCONSTRAINED_REGISTERS} add_options |= angr.options.unicorn-{angr.options.UNICORN_SYM_REGS_SUPPORT} else: add_options |= {angr.options.SYMBOL_FILL_UNCONSTRAINED_MEMORY} add_options |= {angr.options.SYMBOL_FILL_UNCONSTRAINED_REGISTERS} stdin = [] for csymbol,symboltype,entries in inputs: for i in range(entries): varname = 'in_%s_%d'%(csymbol,i) variable = claripy.BVS(varname,typebits[symboltype],explicit_name=True) if compiler_endianness[compiler] == 'little': variable = claripy.Reverse(variable) stdin += [variable] stdin = angr.SimFile('/dev/stdin',content=claripy.Concat(*stdin),has_end=True) if elfulator: avoidsimprocedures = ( 'clock_gettime', '_setjmp', '__setjmp', '___setjmp', 'longjmp', '_longjmp', '__longjmp', '___longjmp', 'sigsetjmp', '_sigsetjmp', '__sigsetjmp', '___sigsetjmp', 'siglongjmp', '_siglongjmp', '__siglongjmp', '___siglongjmp', ) proj = angr.Project('elfulator',exclude_sim_procedures_list=avoidsimprocedures,auto_load_libs=False,force_load_libs=['libunicorn.so']) class posix_memalign(angr.SimProcedure): def run(self,sim_ptr,sim_alignment,sim_size): result = self.state.heap._malloc(sim_size) self.state.memory.store(sim_ptr,result,size=8) return claripy.BVV(0,32) proj.hook_symbol('posix_memalign',posix_memalign()) class sysconf(angr.SimProcedure): def run(self,num): return claripy.BVV(os.sysconf(num.concrete_value),64) proj.hook_symbol('sysconf',sysconf()) class getpagesize(angr.SimProcedure): def run(self): return claripy.BVV(resource.getpagesize(),32) proj.hook_symbol('getpagesize',getpagesize()) class strerror(angr.SimProcedure): def run(self,num): e = os.strerror(num.concrete_value)+'\0' malloc = angr.SIM_PROCEDURES["libc"]["malloc"] where = self.inline_call(malloc,len(e)).ret_expr self.state.memory.store(where,claripy.BVV(e),size=len(e)) return where proj.hook_symbol('strerror',strerror()) proj.hook_symbol('mmap64',angr.procedures.posix.mmap.mmap()) with open(binary,'rb') as f: binarycontents = f.read() state = proj.factory.full_init_state(add_options=add_options,remove_options=remove_options,args=['elfulator',binary,str(len(binarycontents))],stdin=stdin) simfile = angr.SimFile(binary,content=binarycontents) simfile.set_state(state) state.fs.insert(binary,simfile) else: if rust: proj = angr.Project(binary,auto_load_libs=False) else: avoidsimprocedures = ( 'memcmp', 'bcmp', ) proj = angr.Project(binary,exclude_sim_procedures_list=avoidsimprocedures,auto_load_libs=False) state = proj.factory.full_init_state(add_options=add_options,remove_options=remove_options,stdin=stdin) simgr = proj.factory.simgr(state) while True: if len(simgr.errored) > 0: return -1,False,simgr.errored if len(simgr.deadended)+len(simgr.active) > maxsplit: constraintbuf = io.StringIO() constraintnames = ('exited',None,len(simgr.deadended)),('active',None,len(simgr.active)) unroll_print(constraintnames,[claripy.And(*e.solver.constraints) for e in simgr.deadended+simgr.active],constraintbuf) return -1,False,f'saferewrite limiting split to {maxsplit}\n{constraintbuf.getvalue()}' if len(simgr.active) == 0: break if elfulator: # XXX: should find a documented interface to do this if any(type(e.posix.fd[0]._read_pos) != int for e in simgr.active): state.options -= angr.options.unicorn simgr.step() exits = simgr.deadended assert len(exits) > 0 ok = True comment = f'have {len(exits)} exits\n' for epos,e in enumerate(exits): receivedpackets = len(e.posix.stdout.content) expectedpackets = sum(entries for csymbol,symboltype,entries in outputs) if receivedpackets != expectedpackets: ok = False comment += f'exit {epos} stdout packets {receivedpackets} expecting {expectedpackets} stderr packets {len(e.posix.stderr.content)}\n' for packetpos,packet in enumerate(e.posix.stderr.content): if not (packet[0].concrete and packet[1].concrete): comment += f'stderr symbolic packet {packetpos}\n' continue x = packet[0].concrete_value n = packet[1].concrete_value todo = b'' for i in range(n): todo = bytes(bytearray([x&255]))+todo x >>= 8 for line in todo.splitlines(): comment += f'stderr packet {packetpos} line: {line}\n' if not ok: constraintbuf = io.StringIO() constraintnames = ('exited',None,len(exits)), unroll_print(constraintnames,[claripy.And(claripy.true,*e.solver.constraints) for e in exits],constraintbuf) return -1,False,comment+constraintbuf.getvalue() # cannot be safe if there are multiple exits # for equivalence tests we'll merge exits below if len(exits) > 1: mergedexit,_,_ = exits[0].merge(*exits[1:],merge_conditions=[e2.solver.constraints for e2 in exits]) else: mergedexit = exits[0] results = [] packetpos = 0 for csymbol,symboltype,entries in outputs: for i in range(entries): packet = mergedexit.posix.stdout.content[packetpos] packetpos += 1 assert packet[1].concrete_value == typebits[symboltype]//8 xi = packet[0] if compiler_endianness[compiler] == 'little': xi = claripy.Reverse(xi) results += [xi] ispartition = True # XXX: check this return len(exits),ispartition,results def unroll(o,p,i,compiler): word = compiler_word(compiler) implementationdir = '%s/%s/%s' % (o,p,i) builddir = 'build/%s/%s/%s' % (p,i,word) logger = logging.getLogger() while len(logger.handlers) > 0: logger.removeHandler(logger.handlers[-1]) handler = logging.FileHandler(f'{builddir}/unroll-log') handler.setFormatter(angr.misc.loggers.CuteFormatter(should_color=False)) logger.addHandler(handler) inputs,outputs,funargs,funargtypes,crustargs,crustargtypes,rustargs,funname,funret,funrettype,rustfunrettype = op_api[o,p] rust = 'Cargo.toml' in os.listdir(implementationdir) unrolltime = -cputime() numexits,ispartition,unrolled = unroll_worker('%s/analysis-angr'%builddir,inputs,outputs,rust,compiler) if numexits < 1: note(builddir,'warning-unrollerror',unrolled) return o,p,i,compiler,False if not ispartition: note(builddir,'warning-unrollnotpartition') return o,p,i,compiler,False if numexits > 1: note(builddir,'unsafe-unrollsplit-%d'%numexits) with open('%s/analysis/unrolled' % builddir,'w') as f: unroll_print(outputs,unrolled,f) okvars = set(vname for vname,v in unroll_inputvars(inputs)) usedvars = set(v for x in unrolled for v in x.variables) if not usedvars.issubset(okvars): note(builddir,'warning-unrollmem','\n'.join(sorted(usedvars-okvars))+'\n') if not okvars.issubset(usedvars): note(builddir,'warning-unusedinputs','\n'.join(sorted(okvars-usedvars))+'\n') Wdone = set() for x,y in zip(op_x[o,p],opic_y[o,p,i,compiler]): # cpu gave us outputs y given inputs x # does this match unrolled? replacements = {} xpos = 0 for csymbol,symboltype,entries in inputs: for e in range(entries): varname = 'in_%s_%d'%(csymbol,e) variable = claripy.BVS(varname,typebits[symboltype],explicit_name=True) replacements[clarikey(variable)] = claripy.BVV(x[xpos],typebits[symboltype]) xpos += 1 assert xpos == len(x) mismatch = True notestr = 'internal error\n' V = None try: V,W = values(unrolled,replacements) for warn in W: if warn in Wdone: continue Wdone.add(warn) note(builddir,'warning-%s'%warn) except AssertionError: note(builddir,'warning-valuesfailed',traceback.format_exc()) # proceed with z3 fallback below if V is not None: # what does not exactly work: # mismatch = any(yi != V[clarikey(unrolledi)][1] for (yi,unrolledi) in zip(y,unrolled)) # because yi can be signed while V is always unsigned # so sweep through outputs (which we want to do anyway in case of mismatch) mismatch = False notestr = '# mismatch between CPU execution of binary and saferewrite execution of unrolled:\n' for (vname,v),xi in zip(unroll_inputvars(inputs),x): notestr += '%s = %d\n' % (vname,xi) pos = 0 for csymbol,symboltype,entries in outputs: for e in range(entries): varname = 'out_%s_%d'%(csymbol,e) executedvalue = y[pos] unrolledvalue = V[clarikey(unrolled[pos])][1] if symboltype in signedtypes: if unrolledvalue >= 2**(typebits[symboltype]-1): unrolledvalue -= 2**typebits[symboltype] if executedvalue != unrolledvalue: mismatch = True notestr += 'executed_%s = %s\n' % (varname,executedvalue) notestr += 'unrolled_%s = %s\n' % (varname,unrolledvalue) pos += 1 else: # fall back on Z3 for figuring this out s = claripy.Solver(timeout=z3timeout) mismatch = claripy.false for yi,unrolledi in zip(y,unrolled): mismatch = claripy.Or(mismatch,unrolledi.replace_dict(replacements) != yi) s.add(mismatch) mismatch = s.satisfiable() if mismatch: notestr = '# hint: this type of mismatch typically reflects reading undefined memory.\n' notestr += '# mismatch between CPU execution of binary and Z3 execution of unrolled:\n' for vname,v in unroll_inputvars(inputs): notestr += '%s = %s\n' % (vname,s.eval(v,1)[0]) pos = 0 for csymbol,symboltype,entries in outputs: for e in range(entries): varname = 'out_%s_%d'%(csymbol,e) notestr += 'executed_%s = %s\n' % (varname,y[pos]) notestr += 'unrolled_%s = %s\n' % (varname,s.eval(unrolled[pos],1)[0]) pos += 1 if mismatch: note(builddir,'warning-unrollmismatch',notestr) return o,p,i,compiler,False unrolltime += cputime() notetime(builddir,'unroll',unrolltime) return o,p,i,compiler,unrolled def unroll_wrapper(o,p,i,compiler): word = compiler_word(compiler) builddir = 'build/%s/%s/%s' % (p,i,word) setproctitle.setproctitle(f'saferewrite unroll {builddir}') try: _,_,_,_,unrolled = unroll(o,p,i,compiler) with open(f'{builddir}/unroll-pickle','wb') as f: pickle.dump(unrolled,f) except: note(builddir,'warning-unrollexception',traceback.format_exc()) sys.exit(111) if unrolled == False: # warning already issued by unroll() sys.exit(111) def wanttounroll(): for o,p in primitives: for i,compiler in op_compiled[o,p]: if (o,p,i,compiler) in opic_y: yield o,p,i,compiler # seems more robust than multiprocessing for spawning angr workers def starmap_status(fun,todo): result = {} workers = {} todo = iter(todo) while True: if todo is None and len(workers) == 0: return result if len(workers) < os_threads: if todo is not None: todoentry = next(todo,None) if todoentry is None: todo = None if len(workers) == 0: return result else: pid = os.fork() if pid == 0: fun(*todoentry) sys.exit(0) workers[pid] = todoentry continue pid,status = os.wait() if pid in workers: result[workers[pid]] = status del workers[pid] opic_status = starmap_status(unroll_wrapper,wanttounroll()) for o,p,i,compiler in wanttounroll(): status = opic_status[o,p,i,compiler] if status not in (0,256*111): word = compiler_word(compiler) builddir = f'build/{p}/{i}/{word}' note(builddir,'warning-unrollexitcode',f'exit status {status}\n') print('===== equiv (can take tons of time)') def equiv(o,p,i,compiler,source,sourcecompiler): word = compiler_word(compiler) sourceword = compiler_word(sourcecompiler) implementationdir = '%s/%s/%s' % (o,p,i) builddir = 'build/%s/%s/%s' % (p,i,word) inputs,outputs,funargs,funargtypes,crustargs,crustargtypes,rustargs,funname,funret,funrettype,rustfunrettype = op_api[o,p] setproctitle.setproctitle(f'saferewrite equiv {builddir} build/{p}/{source}/{sourceword}') randomtestspass = True for pos,(x,y,z) in enumerate(zip(op_x[o,p],opic_y[o,p,i,compiler],opic_y[o,p,source,sourcecompiler])): if y != z: xstr = input_example_str(inputs,x) ystr = output_example_str(outputs,y) zstr = output_example_str(outputs,z) note(builddir,'unsafe-randomtest-%d-differentfrom-%s-%s' % (pos,source,sourceword),xstr+ystr+zstr) randomtestspass = False # continue through random tests to get statistics if not randomtestspass: if not satvalidation1: return equivtime = -cputime() with open(f'build/{p}/{source}/{sourceword}/unroll-pickle','rb') as f: u1 = pickle.load(f) with open(f'build/{p}/{i}/{word}/unroll-pickle','rb') as f: u2 = pickle.load(f) assert len(u1) == len(u2) # XXX: allow other equivalence-testing techniques s = claripy.Solver(timeout=z3timeout) different = claripy.false for u1j,u2j in zip(u1,u2): different = claripy.Or(different,u1j != u2j) s.add(different) try: mismatch = s.satisfiable() except claripy.errors.ClaripyZ3Error: # avoid crashing on the sort of bug fixed in https://github.com/angr/angr/pull/2887 note(builddir,'warning-z3failed',traceback.format_exc()) return if mismatch: # angr documentation says: # "If you don't add any constraints between two queries, the results will be consistent with each other." example = '' for vname,v in unroll_inputvars(inputs): example += '%s = %s\n' % (vname,s.eval(v,1)[0]) unrolledpos = 0 for csymbol,symboltype,entries in outputs: for i in range(entries): varname = 'out_%s_%d'%(csymbol,i) example += 'source_%s = %s\n' % (varname,s.eval(u1[unrolledpos],1)[0]) example += 'target_%s = %s\n' % (varname,s.eval(u2[unrolledpos],1)[0]) unrolledpos += 1 note(builddir,'unsafe-differentfrom-%s-%s' % (source,sourceword),example) else: note(builddir,'equals-%s-%s' % (source,sourceword)) equivtime += cputime() notetime(builddir,'equiv',equivtime) def wanttoequiv(): for o,p in primitives: # XXX: allow each implementation to choose its source (if available) iclist = [] for i,compiler in op_compiled[o,p]: if opic_status[o,p,i,compiler] != 0: continue iclist += [(i!='ref',len(i),i,compiler_pos(compiler),compiler)] iclist.sort() iclist = [(i,compiler) for _,_,i,_,compiler in iclist] iclistpos = {key:pos for pos,key in enumerate(iclist)} for i,compiler in iclist: if i != 'ref': key = 'ref',compiler if key in iclistpos: if iclistpos[key] < iclistpos[i,compiler]: yield (o,p,i,compiler,*key) continue yield (o,p,i,compiler,*iclist[0]) # this is a self-test for the first compiler starmap_status(equiv,wanttoequiv())