-rwxr-xr-x 54459 saferewrite-20250228/analyze raw
#!/usr/bin/env python3
import sys
import os
import logging
import pickle
import io
import shutil
import subprocess
import multiprocessing
import random
import traceback
import functools
import setproctitle
import resource
import angr
import claripy
import struct
# Integer types accepted in api files, named as in C without the _t suffix.
_typewidths = (8,16,32,64)

# names of the signed types
signedtypes = {'int%d' % width for width in _typewidths}

# type name -> width in bits (signed types first, then unsigned)
typebits = {
    '%sint%d' % (prefix,width): width
    for prefix in ('','u')
    for width in _typewidths
}

# type name -> Rust primitive type name
rusttype = {
    '%sint%d' % (prefix,width): '%s%d' % (rustprefix,width)
    for prefix,rustprefix in (('','i'),('u','u'))
    for width in _typewidths
}

# type name -> struct-module format character (same key order as typebits)
typestruct = dict(zip(typebits,'bhiqBHIQ'))
# ===== performance-related configuration...

numrandomtests = 16
# test 16 random inputs to each function
# (first two are all-0, all-1)
# of course more time on testing and fuzzing would be useful

maxsplit = 300
# max number of universes within an angr run
# XXX: allow per-primitive/per-implementation configuration

satvalidation1 = True
# True means: even if random tests fail, invoke general sat/unsat mechanism
# this allows extra validation of that mechanism

# the following is somewhat ad-hoc monkey-patching
# underlying problem: flatten simplifiers can easily blow up
# Three module layouts are tried in turn. A layout that is absent shows up
# as AttributeError while looking up the module or a simplifier, so only that
# exception moves on to the next layout; anything else (e.g. KeyboardInterrupt)
# is no longer swallowed. If no layout matches, the last AttributeError
# propagates, as before.
try:
    claripy.simplifications.simpleton._simplifiers = {
        'Reverse': claripy.simplifications.simpleton.bv_reverse_simplifier,
        'Extract': claripy.simplifications.simpleton.extract_simplifier,
        'Concat': claripy.simplifications.simpleton.concat_simplifier,
        'ZeroExt': claripy.simplifications.simpleton.zeroext_simplifier,
        'SignExt': claripy.simplifications.simpleton.signext_simplifier,
    }
except AttributeError:
    try:
        claripy.simpleton._simplifiers = {
            'Reverse': claripy.simpleton.bv_reverse_simplifier,
            'Extract': claripy.simpleton.extract_simplifier,
            'Concat': claripy.simpleton.concat_simplifier,
            'ZeroExt': claripy.simpleton.zeroext_simplifier,
            'SignExt': claripy.simpleton.signext_simplifier,
        }
    except AttributeError:
        claripy.simplifications._all_simplifiers = {
            'Reverse': claripy.simplifications.bv_reverse_simplifier,
            'Extract': claripy.simplifications.extract_simplifier,
            'Concat': claripy.simplifications.concat_simplifier,
            'ZeroExt': claripy.simplifications.zeroext_simplifier,
            'SignExt': claripy.simplifications.signext_simplifier,
        }

# term evaluation and printing later in this file recurse over expression trees
sys.setrecursionlimit(1000000)
# ===== real work begins here

z3timeout = (1 << 32) - 1
# XXX: by default claripy satisfiable() has 5-minute timeout
# and indicates timeout by returning False
# current Z3 documentation says:
# timeout (in milliseconds) (UINT_MAX and 0 mean no timeout)

# Worker count: the CPU affinity mask where the platform offers it, otherwise
# the CPU count; the THREADS environment variable overrides either; never
# less than 1.
if hasattr(os,'sched_getaffinity'):
    _default_threads = len(os.sched_getaffinity(0))
else:
    _default_threads = multiprocessing.cpu_count()
os_threads = max(1,int(os.environ.get('THREADS',_default_threads)))
def cputime():
    """Return user CPU seconds reported for RUSAGE_SELF plus RUSAGE_CHILDREN."""
    own = resource.getrusage(resource.RUSAGE_SELF)
    children = resource.getrusage(resource.RUSAGE_CHILDREN)
    return own.ru_utime + children.ru_utime
# raise the soft limit on open file descriptors to the current hard limit
_nofile_hard = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
resource.setrlimit(resource.RLIMIT_NOFILE,(_nofile_hard,_nofile_hard))
def notetime(builddir,what,time):
    """Print a timing line and append it to <builddir>/analysis/seconds.

    builddir: build directory; its analysis/ subdirectory must already exist
    what: label of the phase that was timed
    time: duration in seconds
    """
    print('%s seconds %s %.6f' % (builddir,what,time),flush=True)
    secondsfile = '%s/analysis/seconds' % builddir
    with open(secondsfile,'a') as log:
        log.write('%s %.6f\n' % (what,time))
def note(builddir,conclusion,contents=None):
    """Print a conclusion and record it as the file <builddir>/analysis/<conclusion>.

    The file is created (or truncated) in every case; str(contents) is written
    into it unless contents is None, which leaves the file empty.
    """
    print('%s %s' % (builddir,conclusion),flush=True)
    with open('%s/analysis/%s' % (builddir,conclusion),'w') as out:
        if contents is None:
            return
        out.write(str(contents))
# Remember where we were started and insist that the path uses only a
# conservative character set, then start from an empty build/ directory.
startdir = os.getcwd()
assert not set(startdir) - set('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_./')
shutil.rmtree('build',ignore_errors=True)
os.makedirs('build')
print('===== collecting compilers')

# raw lines of the compilers file, in file order
with open('compilers') as f:
    compilers = f.read().splitlines()

# line text -> index in the file (for a repeated line the last index wins)
compiler_pos_cached = {line: index for index,line in enumerate(compilers)}
def compiler_pos(compiler):
    """Return the index of a compilers-file line; raises KeyError for an unknown line."""
    position = compiler_pos_cached[compiler]
    return position
def compiler_list():
    """Yield, stripped, every compilers-file line that contains a ':'."""
    stripped = (line.strip() for line in compilers)
    yield from (line for line in stripped if ':' in line)
def compiler_command(compiler):
    """Return the command part of a compiler line: everything after the first ':', stripped.

    A line without ':' gives the empty string.
    """
    _tags,_colon,command = compiler.partition(':')
    return command.strip()
def compiler_hastag(compiler,tag):
    """Tell whether tag is one of the whitespace-separated words before the first ':'."""
    tags = compiler.partition(':')[0].split()
    return tag in tags
def compiler_hasarch(compiler,arch):
    """Tell whether the compiler satisfies an implementation's architecture list.

    arch is None (no restriction) or an iterable of lines; each line is a
    whitespace-separated set of tags, and one line whose tags are all present
    on the compiler is enough.
    """
    if arch is None:
        return True
    for alternative in arch:
        if all(compiler_hastag(compiler,tag) for tag in alternative.split()):
            return True
    return False
def compiler_word(compiler):
    """Return the compiler command with every ' ' and '=' replaced by '_' (used as a directory name)."""
    return compiler_command(compiler).translate(str.maketrans(' =','__'))
# compiler line -> 'little' or 'big', for compilers whose probe succeeded
compiler_endianness = {}

def checkcompiler(compiler):
    """Compile and run a small probe program to learn the compiler's target endianness.

    The probe is written to build/compiler/<word>/endianness.<c|cpp>, built
    with the compiler's command plus -static, and executed. Its exit code is
    12 for little-endian byte order, 21 for big-endian, 99 otherwise.
    Conclusions and warnings are recorded with note().

    Returns (compiler, endianness, ok): endianness is 'little' or 'big' with
    ok True, or None with ok False if compiling or running the probe failed.
    """
    word = compiler_word(compiler)
    builddir = f'build/compiler/{word}'
    os.makedirs(f'{builddir}/analysis',exist_ok=True)
    language = 'cpp' if compiler_hastag(compiler,'cpp') else 'c'
    with open(f'{builddir}/endianness.{language}','w') as f:
        f.write('''#include <stdlib.h>
#include <inttypes.h>
int main()
{
uint16_t a = 0x0201;
uint32_t b = 0x04030201;
uint64_t c = 0x0807060504030201ULL;
char *A = (char *) &a;
char *B = (char *) &b;
char *C = (char *) &c;
if (A[0] == 1 && A[1] == 2 &&
B[0] == 1 && B[1] == 2 && B[2] == 3 && B[3] == 4 &&
C[0] == 1 && C[1] == 2 && C[2] == 3 && C[3] == 4 &&
C[4] == 5 && C[5] == 6 && C[6] == 7 && C[7] == 8)
return 12;
if (A[0] == 2 && A[1] == 1 &&
B[0] == 4 && B[1] == 3 && B[2] == 2 && B[3] == 1 &&
C[0] == 8 && C[1] == 7 && C[2] == 6 && C[3] == 5 &&
C[4] == 4 && C[5] == 3 && C[6] == 2 && C[7] == 1)
return 21;
return 99;
}
''')

    # ----- compile the probe
    command = f'{compiler_command(compiler)} -static -o endianness endianness.{language}'
    try:
        proc = subprocess.Popen(command.split(),cwd=builddir,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True)
        out,err = proc.communicate()
    except OSError: # includes FileNotFoundError
        note(builddir,f'warning-compilefailed-endianness.{language}',traceback.format_exc())
        return compiler,None,False
    assert not err # stderr is merged into stdout above
    if out != '':
        # fixed: these two notes previously lacked the f prefix and were
        # written to files literally named '...endianness.{language}'
        note(builddir,f'warning-compileoutput-endianness.{language}',out)
    if proc.returncode:
        note(builddir,f'warning-compilefailed-endianness.{language}',f'exit code {proc.returncode}\n')
        return compiler,None,False

    # ----- run the probe
    command = './endianness'
    try:
        proc = subprocess.Popen(command.split(),cwd=builddir,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True)
        out,err = proc.communicate()
    except OSError:
        note(builddir,'warning-executefailed-endianness',traceback.format_exc())
        return compiler,None,False
    assert not err
    if out != '':
        note(builddir,'warning-executeoutput-endianness',out)
    if proc.returncode == 12:
        note(builddir,'endianness-little')
        return compiler,'little',True
    if proc.returncode == 21:
        note(builddir,'endianness-big')
        return compiler,'big',True
    note(builddir,'warning-executefailed-endianness',f'exit code {proc.returncode}\n')
    return compiler,None,False
# probe every listed compiler in parallel and keep the ones that worked
with multiprocessing.Pool(os_threads) as pool:
    _probes = pool.map(checkcompiler,compiler_list(),chunksize=1)
compiler_endianness.update(
    (compiler,endianness) for compiler,endianness,ok in _probes if ok
)
def compiler_list2():
    """Yield, in file order, the compilers-file lines that have a recorded endianness.

    The raw line text is what gets looked up in compiler_endianness.
    """
    yield from (line for line in compilers if line in compiler_endianness)
print('===== collecting primitives')

# (o,p) pairs: p is a subdirectory of operation directory o that has an api file.
# Directories with the sticky bit (mode 0o1000) set are skipped.
primitives = []
for o in ('src',):
    o = o.strip()
    if o == '' or not os.path.isdir(o):
        continue
    if os.stat('%s' % o).st_mode & 0o1000:
        print('%s sticky, skipping' % o,flush=True)
        continue
    for p in sorted(os.listdir(o)):
        if p == 'compiler' or not os.path.isdir('%s/%s' % (o,p)):
            continue
        if os.stat('%s/%s' % (o,p)).st_mode & 0o1000:
            print('%s/%s sticky, skipping' % (o,p),flush=True)
            continue
        if not os.path.exists('%s/%s/api' % (o,p)):
            print('%s/%s/api nonexistent, skipping' % (o,p),flush=True)
            continue
        primitives.append((o,p))
# op_api[o,p]: what the api file of primitive (o,p) declares, as the tuple
# inputs,outputs,funargs,funargtypes,crustargs,crustargtypes,rustargs,
# funname,funret,funrettype,rustfunrettype
# api lines understood here:
#   call <function name>
#   return <type> <symbol>
#   in|out|inout <type> <symbol> [<entries>]   (entries present: pointer argument)
op_api = {}
for o,p in primitives:
    inputs = []
    outputs = []
    funargs = []
    funargtypes = []
    rustargs = []
    rustargtypes = []
    crustargs = []
    crustargtypes = []
    funname = None
    funret = None
    funrettype = 'void'
    rustfunrettype = None
    with open('%s/%s/api' % (o,p)) as f:
        for rawline in f:
            fields = rawline.split()
            if not fields:
                continue
            kind = fields[0]
            if kind == 'call':
                funname = fields[1]
                assert all(x in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_' for x in funname)
            elif kind == 'return':
                symboltype = fields[1]
                csymbol = fields[2]
                assert all(c in 'abcdefghijklmnopqrstuvwxyz' for c in csymbol)
                entries = 1
                outputs.append((csymbol,symboltype,entries))
                funret = 'var_%s'%csymbol
                funrettype = f'{symboltype}_t'
                rustfunrettype = rusttype[symboltype]
            elif kind in ('in','out','inout'):
                symboltype = fields[1]
                csymbol = fields[2]
                assert all(c in 'abcdefghijklmnopqrstuvwxyz' for c in csymbol)
                pointer = len(fields) != 3
                entries = int(fields[3]) if pointer else 1
                if kind != 'out':
                    inputs.append((csymbol,symboltype,entries))
                if kind != 'in':
                    outputs.append((csymbol,symboltype,entries))
                readonly = (kind == 'in')
                if pointer:
                    funargs.append('var_%s'%csymbol)
                    if readonly:
                        funargtypes.append(f'const {symboltype}_t *')
                    else:
                        funargtypes.append(f'{symboltype}_t *')
                    crustargs.append('c_%s'%csymbol)
                    rustargs.append('rust_%s'%csymbol)
                    if readonly:
                        crustargtypes.append('*const %s' % rusttype[symboltype])
                    else:
                        crustargtypes.append('*mut %s' % rusttype[symboltype])
                else:
                    funargs.append('var_%s[0]'%csymbol)
                    funargtypes.append(f'{symboltype}_t')
                    crustargs.append('c_%s'%csymbol)
                    rustargs.append('rust_%s'%csymbol)
                    crustargtypes.append('const %s' % rusttype[symboltype])
                # XXX: support constant inputs
    op_api[o,p] = inputs,outputs,funargs,funargtypes,crustargs,crustargtypes,rustargs,funname,funret,funrettype,rustfunrettype
def input_example_str(inputs,x):
    """Render a flat list of input values as 'in_<symbol>_<index> = <value>' lines.

    inputs: (csymbol,symboltype,entries) triples giving the layout of x
    x: one value per entry, in layout order; its length must match the layout
    """
    lines = []
    for csymbol,symboltype,entries in inputs:
        for e in range(entries):
            label = 'in_%s_%d'%(csymbol,e)
            lines.append('%s = %d\n' % (label,x[len(lines)]))
    assert len(lines) == len(x)
    return ''.join(lines)
def output_example_str(outputs,y):
    """Render a flat list of output values as 'out_<symbol>_<index> = <value>' lines.

    outputs: (csymbol,symboltype,entries) triples giving the layout of y
    y: one value per entry, in layout order; its length must match the layout
    """
    lines = []
    for csymbol,symboltype,entries in outputs:
        for e in range(entries):
            label = 'out_%s_%d'%(csymbol,e)
            lines.append('%s = %d\n' % (label,y[len(lines)]))
    assert len(lines) == len(y)
    return ''.join(lines)
# File names that an implementation directory is not allowed to contain.
reservedfilenames = (
    'library.so.1',
    'unroll-pickle',
    'unroll-log',
    'analysis',
) + tuple(
    'analysis-%s%s' % (analysis,suffix)
    for analysis in ('execute','valgrind','angr')
    for suffix in ('','.c','.cpp')
) + (
    'subroutines.c',
    'subroutines.cpp',
)
# opi_language[o,p,i]: 'c' or 'cpp'
# opi_architectures[o,p,i]: lines of the implementation's architectures file, or None
# opimplementations[o,p]: names of the accepted implementation directories
opi_language = {}
opi_architectures = {}
opimplementations = {}
for o,p in primitives:
    opimplementations[o,p] = []
    for i in sorted(os.listdir('%s/%s' % (o,p))):
        implementationdir = '%s/%s/%s' % (o,p,i)
        if not os.path.isdir(implementationdir): continue
        if os.stat(implementationdir).st_mode & 0o1000 == 0o1000:
            print('%s/%s/%s sticky, skipping' % (o,p,i))
            continue
        files = sorted(os.listdir(implementationdir))
        # fixed: ok used to be reset to True for every file, so only the last
        # file in sorted order decided whether the implementation was skipped;
        # now any offending file rejects the implementation
        ok = True
        for f in files:
            if f in reservedfilenames:
                print('%s/%s/%s/%s reserved filename' % (o,p,i,f))
                ok = False
            if any(fi not in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.' for fi in f):
                print('%s/%s/%s/%s prohibited character' % (o,p,i,f))
                ok = False
        if not ok: continue
        opimplementations[o,p] += [i]
        # any C++ source file makes the whole implementation C++
        # fixed: the test was endswith('cc'), without the dot
        language = 'c'
        for f in files:
            if f.endswith('.cc') or f.endswith('.cpp'):
                language = 'cpp'
        opi_language[o,p,i] = language
        opi_architectures[o,p,i] = None
        if 'architectures' in files:
            with open(f'{implementationdir}/architectures') as f:
                opi_architectures[o,p,i] = f.read().splitlines()
def compile(o,p,i,compiler):
    """Build the analysis-execute, analysis-valgrind and analysis-angr binaries
    for implementation i of primitive (o,p) with one compiler.

    Steps:
    - copy the implementation directory to build/<p>/<i>/<word>
    - write crypto_int*.h / crypto_uint*.h headers that are not already there
    - write subroutines.<language> (memcmp and bcmp)
    - generate one driver source per analysis; the execute and angr drivers
      read the inputs from fd 0 and write the outputs to fd 1
    - if the implementation has a Cargo.toml: generate Rust wrappers and
      build with cargo; otherwise compile every source file and link

    Warnings and timings are recorded with note() and notetime().
    Returns (o,p,i,compiler,ok) with ok False as soon as any step fails.
    """
    language = opi_language[o,p,i]
    word = compiler_word(compiler)
    implementationdir = '%s/%s/%s' % (o,p,i)
    inputs,outputs,funargs,funargtypes,crustargs,crustargtypes,rustargs,funname,funret,funrettype,rustfunrettype = op_api[o,p]
    elfulator = compiler_hastag(compiler,'elfulator')
    files = sorted(os.listdir(implementationdir))
    files += [f'subroutines.{language}']
    rust = 'Cargo.toml' in files
    rustlocked = 'Cargo.lock' in files
    cfiles = [x for x in files if x.endswith('.c') or x.endswith('.cc') or x.endswith('.cpp')]
    sfiles = [x for x in files if x.endswith('.s') or x.endswith('.S')]
    files = cfiles + sfiles # sources whose objects are linked into every binary
    afiles = list(files) # everything to compile: files plus the three drivers
    os.makedirs('build/%s/%s' % (p,i),exist_ok=True)
    builddir = 'build/%s/%s/%s' % (p,i,word)
    shutil.copytree(implementationdir,builddir)
    os.makedirs('%s/analysis' % builddir)
    for bits in 8,16,32,64:
        if not os.path.exists(f'{builddir}/crypto_int{bits}.h'):
            with open('%s/crypto_int%d.h' % (builddir,bits),'w') as f:
                f.write('#include <inttypes.h>\n')
                f.write('#define crypto_int%d int%d_t' % (bits,bits))
        if not os.path.exists(f'{builddir}/crypto_uint{bits}.h'):
            with open('%s/crypto_uint%d.h' % (builddir,bits),'w') as f:
                f.write('#include <inttypes.h>\n')
                f.write('#define crypto_uint%d uint%d_t' % (bits,bits))
    with open(f'{builddir}/subroutines.{language}','w') as f:
        f.write('''#include <stddef.h>
int memcmp(const void *u,const void *v,size_t n)
{
unsigned char *x = (unsigned char *) u;
unsigned char *y = (unsigned char *) v;
for (;n > 0;--n,++x,++y) {
int result = *x;
result -= *y;
if (result) return result;
}
return 0;
}
int bcmp(const void *u,const void *v,size_t n)
{
return memcmp(u,v,n);
}
''')
    for analysis in 'execute','valgrind','angr':
        afiles += ['analysis-%s.%s' % (analysis,language)]
        with open('%s/analysis-%s.%s' % (builddir,analysis,language),'w') as f:
            if analysis != 'valgrind':
                f.write('#include <unistd.h>\n')
            f.write('#include <stdlib.h>\n')
            f.write('#include <string.h>\n')
            f.write('#include <inttypes.h>\n')
            f.write('\n')
            # function declaration
            f.write('extern ')
            if funrettype is not None:
                f.write('%s ' % funrettype)
            f.write('%s(%s);\n' % (funname,','.join(funargtypes)))
            f.write('\n')
            # avoidalloc: global arrays instead of malloc inside main
            avoidalloc = (analysis == 'execute') or (analysis == 'angr' and elfulator)
            if avoidalloc:
                for csymbol,symboltype,entries in inputs:
                    f.write(f'{symboltype}_t var_{csymbol}[{entries}];\n')
                for csymbol,symboltype,entries in outputs:
                    if (csymbol,symboltype,entries) not in inputs:
                        f.write(f'{symboltype}_t var_{csymbol}[{entries}];\n')
                f.write('\n')
            f.write('int main(int argc,char **argv)\n')
            f.write('{\n')
            if not avoidalloc:
                for csymbol,symboltype,entries in inputs:
                    allocbytes = entries*typebits[symboltype]//8
                    f.write(f' {symboltype}_t *var_{csymbol} = ({symboltype}_t *) malloc({allocbytes});\n')
                for csymbol,symboltype,entries in outputs:
                    if (csymbol,symboltype,entries) not in inputs:
                        allocbytes = entries*typebits[symboltype]//8
                        f.write(f' {symboltype}_t *var_{csymbol} = ({symboltype}_t *) malloc({allocbytes});\n')
                f.write('\n')
            # XXX: resource limits
            if analysis != 'valgrind':
                # read every input entry from fd 0
                for csymbol,symboltype,entries in inputs:
                    allocbytes = typebits[symboltype]//8
                    f.write(' for (long long i = 0;i < %d;++i)\n' % entries)
                    f.write(' if (read(0,(char *) &var_%s[i],%s) != %s)\n' % (csymbol,allocbytes,allocbytes))
                    f.write(' exit(111);\n')
                f.write('\n')
            f.write(' ')
            if funret is not None:
                f.write('%s[0] = ' % funret)
            f.write('%s(%s);\n' % (funname,','.join(funargs)))
            f.write('\n')
            if analysis != 'valgrind':
                # write every output entry to fd 1
                for csymbol,symboltype,entries in outputs:
                    allocbytes = typebits[symboltype]//8
                    f.write(' for (long long i = 0;i < %d;++i)\n' % entries)
                    f.write(' if (write(1,(char *) &var_%s[i],%s) != %s)\n' % (csymbol,allocbytes,allocbytes))
                    f.write(' exit(111);\n')
                f.write('\n')
            f.write(' return 0;\n')
            f.write('}\n')

    # ===== rust
    if rust:
        os.makedirs('%s/src/bin'%builddir,exist_ok=True)
        with open('%s/Cargo.toml'%builddir,'a') as f:
            f.write("""\
[[bin]]
name = "analysis-execute"
[[bin]]
name = "analysis-valgrind"
[[bin]]
name = "analysis-angr"
[profile.release]
panic = "abort"
""")
        with open('%s/build.rs'%builddir,'w') as f:
            f.write("""\
use std::env;
fn main() {
let out_dir = env::var("OUT_DIR").unwrap();
println!("cargo:rustc-link-search=native={}",out_dir);
cc::Build::new()
.cargo_metadata(false)
.file("analysis-execute.c")
.compile("analysis-execute");
println!("cargo:rerun-if-changed=analysis-execute.c");
cc::Build::new()
.cargo_metadata(false)
.file("analysis-valgrind.c")
.compile("analysis-valgrind");
println!("cargo:rerun-if-changed=analysis-valgrind.c");
cc::Build::new()
.cargo_metadata(false)
.file("analysis-angr.c")
.compile("analysis-angr");
println!("cargo:rerun-if-changed=analysis-angr.c");
}
""")
        for analysis in 'execute','valgrind','angr':
            # Rust side: export funname with a C ABI and forward to the crate's function
            with open('%s/src/bin/analysis-%s.rs'%(builddir,analysis),'w') as f:
                f.write('#![no_std]\n')
                f.write('#![no_main]\n')
                f.write('extern crate %s;\n'%funname)
                f.write('\n')
                if any(f'rust_{csymbol}' in rustargs for (csymbol,symboltype,entries) in outputs):
                    f.write('use core::convert::TryInto;\n')
                f.write('use core::slice;\n')
                f.write('\n')
                f.write('#[link(name="analysis-%s")]\n'%analysis)
                f.write('extern {\n')
                f.write('#[allow(dead_code)]\n')
                f.write(' fn main() -> i32;\n')
                f.write('}\n')
                f.write('\n')
                f.write('#[no_mangle]\n')
                f.write('pub unsafe extern "C"\n')
                f.write('fn %s(\n'%funname)
                for arg,argtype in zip(crustargs,crustargtypes):
                    f.write(' %s:%s,\n' % (arg,argtype))
                if rustfunrettype is None:
                    f.write(') -> i32 {\n')
                else:
                    f.write(') -> %s {\n' % rustfunrettype)
                for csymbol,symboltype,entries in outputs:
                    if f'rust_{csymbol}' not in rustargs: continue
                    f.write(' let rust_%s = slice::from_raw_parts_mut(c_%s,%d);\n' % (csymbol,csymbol,entries))
                    f.write(' let rust_%s:&mut[%s;%d] = rust_%s.try_into().unwrap();\n' % (csymbol,rusttype[symboltype],entries,csymbol))
                for csymbol,symboltype,entries in inputs:
                    if (csymbol,symboltype,entries) not in outputs:
                        f.write(' let rust_%s = slice::from_raw_parts(c_%s,%d);\n' % (csymbol,csymbol,entries))
                if rustfunrettype is None:
                    f.write(' %s::%s(%s);\n' % (funname,funname,','.join(rustargs)))
                    f.write(' 0\n')
                else:
                    f.write(' %s::%s(%s)\n' % (funname,funname,','.join(rustargs)))
                f.write('}\n')
        cargotime = -cputime()
        # fixed: work on a copy; assigning into os.environ itself changed the
        # environment of this pool worker for everything it ran afterwards
        env = dict(os.environ)
        env['CC'] = compiler_command(compiler).split()[0]
        env['CFLAGS'] = ' '.join(compiler_command(compiler).split()[1:])
        env['RUSTFLAGS'] = '-C target_cpu=native -C link-args=-no-pie' # XXX: let user control this?
        if rustlocked:
            command = 'cargo build -q --locked --release --bin analysis-execute --bin analysis-valgrind --bin analysis-angr'
        else:
            command = 'cargo build -q --release --bin analysis-execute --bin analysis-valgrind --bin analysis-angr'
        try:
            proc = subprocess.Popen(command.split(),cwd=builddir,env=env,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True)
            out,err = proc.communicate()
        except (OSError,FileNotFoundError):
            note(builddir,'warning-cargofailed',traceback.format_exc())
            return o,p,i,compiler,False
        assert not err # stderr is merged into stdout
        if out != '':
            note(builddir,'warning-cargooutput',out)
        if proc.returncode:
            note(builddir,'warning-cargofailed','exit code %s\n' % proc.returncode)
            return o,p,i,compiler,False
        for analysis in 'execute','valgrind','angr':
            shutil.copy('%s/target/release/analysis-%s'%(builddir,analysis),'%s/analysis-%s'%(builddir,analysis))
        cargotime += cputime()
        notetime(builddir,'cargo',cargotime)
        return o,p,i,compiler,True

    # ===== compile
    compiletime = -cputime()
    objfiles = []
    for f in afiles:
        command = '%s -Wall -Wextra -Wno-unused-function -Wno-unused-parameter -fPIC -DCRYPTO_NAMESPACE(x)=x -c %s' % (compiler_command(compiler),f)
        if f in ('analysis-valgrind.c','analysis-valgrind.cpp'):
            # the valgrind driver passes its buffers to the function without filling them
            command += ' -Wno-uninitialized'
        try:
            proc = subprocess.Popen(command.split(),cwd=builddir,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True)
            out,err = proc.communicate()
        except (OSError,FileNotFoundError):
            note(builddir,'warning-compilefailed-%s'%f,traceback.format_exc())
            return o,p,i,compiler,False
        assert not err
        if out != '':
            note(builddir,'warning-compileoutput-%s'%f,out)
        if proc.returncode:
            note(builddir,'warning-compilefailed-%s'%f,'exit code %s\n' % proc.returncode)
            return o,p,i,compiler,False
        if f in files:
            # driver objects are not shared; each is linked only into its own binary
            objfiles += ['.'.join(f.split('.')[:-1]+['o'])]
    compiletime += cputime()
    notetime(builddir,'compile',compiletime)

    # ===== link into executable
    linktime = -cputime()
    for analysis in 'execute','valgrind','angr':
        directlink = True
        if directlink:
            command = f'{compiler_command(compiler)} -no-pie -o analysis-{analysis} analysis-{analysis}.o'
            command = command.split()
            command += objfiles
        else:
            # alternative path (disabled by directlink above): build a shared
            # library first, then link the driver against it
            command = f'{compiler_command(compiler)} -shared -Wl,-soname,library.so.1 -o library.so.1'
            command = command.split()
            command += objfiles
            try:
                proc = subprocess.Popen(command,cwd=builddir,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True)
                out,err = proc.communicate()
            except (OSError,FileNotFoundError):
                note(builddir,'warning-linkfailed-%s'%analysis,traceback.format_exc())
                return o,p,i,compiler,False
            if out != '':
                note(builddir,'warning-linkoutput-%s'%analysis,out)
            assert not err
            if proc.returncode:
                note(builddir,'warning-linkfailed-%s'%analysis,'exit code %s\n' % proc.returncode)
                return o,p,i,compiler,False
            shutil.copy('%s/library.so.1' % builddir,'%s/library.so' % builddir)
            command = f'{compiler_command(compiler)} -no-pie -o analysis-%s analysis-%s.o -Wl,-rpath=%s/%s -L. -lrary' % (analysis,analysis,startdir,builddir)
            command = command.split()
        try:
            proc = subprocess.Popen(command,cwd=builddir,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True)
            out,err = proc.communicate()
        except (OSError,FileNotFoundError):
            note(builddir,'warning-linkfailed-%s'%analysis,traceback.format_exc())
            return o,p,i,compiler,False
        if out != '':
            note(builddir,'warning-linkoutput-%s'%analysis,out)
        assert not err
        if proc.returncode:
            note(builddir,'warning-linkfailed-%s'%analysis,'exit code %s\n' % proc.returncode)
            return o,p,i,compiler,False
    linktime += cputime()
    notetime(builddir,'link',linktime)
    return o,p,i,compiler,True
def wanttocompile():
    """Yield every (o,p,i,compiler) whose compiler is tagged with the
    implementation's language and satisfies its architectures list."""
    for o,p in primitives:
        for i in opimplementations[o,p]:
            for compiler in compiler_list2():
                if not compiler_hastag(compiler,opi_language[o,p,i]):
                    continue
                if not compiler_hasarch(compiler,opi_architectures[o,p,i]):
                    continue
                yield o,p,i,compiler
# op_compiled[o,p]: (implementation,compiler) pairs that built successfully
op_compiled = {(o,p): [] for o,p in primitives}
with multiprocessing.Pool(os_threads) as pool:
    for o,p,i,compiler,ok in pool.starmap(compile,wanttocompile(),chunksize=1):
        if ok:
            op_compiled[o,p].append((i,compiler))
print('===== execute')

# op_x[o,p]: numrandomtests input vectors for primitive (o,p), each a flat
# list with one integer per input entry. Vector 0 is all zeros, vector 1 is
# -1 for signed types and the maximum value for unsigned types, and the
# remaining vectors are drawn from the random module.
op_x = {}
for o,p in primitives:
    inputs,outputs,funargs,funargtypes,crustargs,crustargtypes,rustargs,funname,funret,funrettype,rustfunrettype = op_api[o,p]
    op_x[o,p] = []
    for execution in range(numrandomtests):
        x = []
        for csymbol,symboltype,entries in inputs:
            for e in range(entries):
                if execution >= 2:
                    r = random.randrange(2**typebits[symboltype])
                    if symboltype in signedtypes:
                        # shift into the signed range
                        r -= 2**(typebits[symboltype]-1)
                elif execution == 1:
                    r = -1 if symboltype in signedtypes else 2**typebits[symboltype]-1
                else:
                    r = 0
                x.append(r)
        op_x[o,p].append(x)
def execute(o,p,i,compiler):
    """Run ./analysis-execute once per test vector in op_x[o,p] and decode its outputs.

    Each vector is packed with the struct module in the compiler's byte order
    and sent on stdin; stdout must contain exactly the packed outputs.

    Returns (o,p,i,compiler,results): results is a list with one flat output
    list per test vector, or False after the first failing run (the failure
    is recorded with note()).
    """
    word = compiler_word(compiler)
    builddir = 'build/%s/%s/%s' % (p,i,word)
    inputs,outputs = op_api[o,p][:2]
    byteorder = '<' if compiler_endianness[compiler] == 'little' else '>'
    executetime = -cputime()
    results = []
    command = ['./analysis-execute']
    for x in op_x[o,p]:
        # ----- pack the inputs
        xstr = b''
        xpos = 0
        for csymbol,symboltype,entries in inputs:
            for e in range(entries):
                if symboltype in signedtypes:
                    assert x[xpos] >= -2**(typebits[symboltype]-1)
                    assert x[xpos] < 2**(typebits[symboltype]-1)
                else:
                    assert x[xpos] >= 0
                    assert x[xpos] < 2**typebits[symboltype]
                xstr += struct.pack(byteorder+typestruct[symboltype],x[xpos])
                xpos += 1
        assert xpos == len(x)

        # ----- run
        try:
            proc = subprocess.Popen(command,cwd=builddir,stdin=subprocess.PIPE,stdout=subprocess.PIPE,stderr=subprocess.STDOUT)
            ystr,err = proc.communicate(input=xstr)
        except OSError: # includes FileNotFoundError
            note(builddir,'warning-executeerror',xstr)
            return o,p,i,compiler,False
        if proc.returncode != 0:
            note(builddir,'warning-executefailed',f'exit code {proc.returncode}\n'+input_example_str(inputs,x))
            return o,p,i,compiler,False

        # ----- unpack the outputs
        # fixed: too little output makes struct.unpack raise struct.error,
        # which is not a ValueError and used to escape this handler; too much
        # output used to trip an uncaught assert. Both now count as bad format.
        # (No range check on the unpacked values: unpack cannot yield a value
        # outside the range of its format.)
        try:
            y = []
            ystrpos = 0
            for csymbol,symboltype,entries in outputs:
                structbytes = typebits[symboltype]//8
                for e in range(entries):
                    y += list(struct.unpack(byteorder+typestruct[symboltype],ystr[ystrpos:ystrpos+structbytes]))
                    ystrpos += structbytes
            if ystrpos != len(ystr):
                raise ValueError('unexpected trailing output')
        except (ValueError,struct.error):
            note(builddir,'warning-executebadformat',input_example_str(inputs,x))
            return o,p,i,compiler,False
        results += [y]
    executetime += cputime()
    notetime(builddir,'execute',executetime)
    return o,p,i,compiler,results
def wanttoexecute():
    """Yield (o,p,i,compiler) for every successfully compiled implementation."""
    for o,p in primitives:
        yield from ((o,p,i,compiler) for i,compiler in op_compiled[o,p])
# opic_y[o,p,i,compiler]: outputs observed for each test vector
# (builds whose execution failed report False and are left out)
opic_y = {}
with multiprocessing.Pool(os_threads) as pool:
    for o,p,i,compiler,results in pool.starmap(execute,wanttoexecute(),chunksize=1):
        if results is not False:
            opic_y[o,p,i,compiler] = results
print('===== valgrind (can take some time)')

def valgrind(o,p,i,compiler):
    """Run ./analysis-valgrind under valgrind and record the verdict with note().

    Verdicts: 'unsafe-valgrindfailure' for exit code 99 (valgrind's
    --error-exitcode) or for output mentioning 'client request',
    'warning-valgrinderror' if valgrind cannot be started or exits with
    another nonzero code, 'passed-valgrind' otherwise.
    """
    builddir = 'build/%s/%s/%s' % (p,i,compiler_word(compiler))
    valgrindtime = -cputime()
    command = ['valgrind','-q','--error-exitcode=99','./analysis-valgrind']
    try:
        proc = subprocess.Popen(command,cwd=builddir,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True)
        out,err = proc.communicate()
    except (OSError,FileNotFoundError):
        valgrindstatus = 'warning-valgrinderror'
    else:
        assert not err
        if proc.returncode == 99:
            valgrindstatus = 'unsafe-valgrindfailure'
        elif proc.returncode != 0:
            valgrindstatus = 'warning-valgrinderror'
        elif 'client request' in out:
            valgrindstatus = 'unsafe-valgrindfailure'
        else:
            valgrindstatus = 'passed-valgrind'
    note(builddir,valgrindstatus)
    valgrindtime += cputime()
    notetime(builddir,'valgrind',valgrindtime)
def wanttovalgrind():
    """Yield (o,p,i,compiler) for compiled implementations whose compiler has the 'valgrind' tag."""
    for o,p in primitives:
        for i,compiler in op_compiled[o,p]:
            if not compiler_hastag(compiler,'valgrind'):
                continue
            yield o,p,i,compiler
# run all valgrind jobs in parallel; results are recorded on disk by valgrind()
with multiprocessing.Pool(os_threads) as pool:
    pool.starmap(valgrind,wanttovalgrind(),chunksize=1)
print('===== unroll (can take tons of time)')
# XXX: could do this in parallel with valgrind

def clarikey(e):
    """Return a key identifying the term e, for use in dictionaries.

    Uses e.cache_key where that attribute exists and e.hash() otherwise.
    Only a missing attribute triggers the fallback; any other exception
    propagates instead of being swallowed.
    """
    try:
        # e.g. angr 9.2.102
        return e.cache_key
    except AttributeError:
        # e.g. angr 9.2.144
        return e.hash()
def values(terms,replacements):
    # Evaluate terms concretely in Python, without calling a solver.
    # input: replacements mapping clarikey of each variable to a constant term
    #   (the integer value is read from .args[0] of that term)
    # output: pair V,W
    #   V: dictionary mapping clarikey to pairs (b,i) where i is a b-bit value
    #     (booleans are stored with b = 1)
    #   V includes all terms (and every subterm that was evaluated)
    #   _or_ V is None if terms use variables outside replacements
    #     or an operation that is not handled below
    #   W: set of warning labels ('mul','div') for operations that were used
    V = {}
    W = set() # warnings
    def evaluate(t):
        # Store the value of t (and of its subterms) in V.
        # Returns True on success, False if t cannot be evaluated here.
        if clarikey(t) in V:
            return True
        # ----- leaves
        if t.op == 'BoolV':
            V[clarikey(t)] = 1,t.args[0]
            return True
        if t.op == 'BVV':
            V[clarikey(t)] = t.size(),t.args[0]
            return True
        if t.op == 'BVS':
            if clarikey(t) not in replacements: return False
            V[clarikey(t)] = t.size(),replacements[clarikey(t)].args[0]
            return True
        # ----- operations whose leading arguments are plain integers, not terms
        if t.op == 'Extract':
            assert len(t.args) == 3
            top = t.args[0]
            bot = t.args[1]
            if not evaluate(t.args[2]): return False
            x0 = V[clarikey(t.args[2])]
            assert x0[0] > top
            assert top >= bot
            assert bot >= 0
            # keep bits bot..top inclusive: mask off everything above top, then shift
            V[clarikey(t)] = top+1-bot,((x0[1] & ((2<<top)-1)) >> bot)
            return True
        if t.op in ('SignExt','ZeroExt'):
            assert len(t.args) == 2
            if not evaluate(t.args[1]): return False
            x0bits,x0 = V[clarikey(t.args[1])]
            extend = t.args[0]
            assert extend >= 0
            if t.op == 'SignExt':
                if x0 >= (1<<(x0bits-1)):
                    # top bit set: reinterpret as negative, then reduce modulo the wider size
                    x0 -= 1<<x0bits
                    x0 += 1<<(x0bits+extend)
            V[clarikey(t)] = x0bits+extend,x0
            return True
        # ----- everything below takes only terms as arguments: evaluate them all first
        for a in t.args:
            if not evaluate(a): return False
        x = [V[clarikey(a)] for a in t.args]
        if t.op == 'Concat':
            # first argument ends up in the most significant bits
            y = 0
            ybits = 0
            for xbitsi,xi in x:
                y <<= xbitsi
                y += xi
                ybits += xbitsi
            V[clarikey(t)] = ybits,y
            return True
        if t.op == 'Reverse':
            # reverse the order of the bytes
            assert len(x) == 1
            xbits0,x0 = x[0]
            ybits = xbits0
            assert ybits%8 == 0
            y = 0
            for i in range(ybits//8):
                y = (y<<8)+((x0>>(8*i))&255)
            V[clarikey(t)] = ybits,y
            return True
        if t.op in ('__eq__','__ne__'):
            assert len(x) == 2
            assert x[0][0] == x[1][0]
            if t.op == '__eq__': V[clarikey(t)] = 1,(x[0][1]==x[1][1])
            elif t.op == '__ne__': V[clarikey(t)] = 1,(x[0][1]!=x[1][1])
            else:
                print('values: internal error %s, falling back to Z3' % t.op)
                return False
            return True
        if t.op in ('__add__','__mul__','SDiv','__floordiv__','SMod','__mod__','__sub__','__lshift__','LShR','__rshift__','__and__','__or__','__xor__'):
            # all operands have the same width; the operation is folded
            # left to right over the operands with functools.reduce
            bits = x[0][0]
            assert all(xi[0] == bits for xi in x)
            if t.op == '__and__': reduction = (lambda s,t:s&t)
            elif t.op == '__or__': reduction = (lambda s,t:s|t)
            elif t.op == '__xor__': reduction = (lambda s,t:s^t)
            elif t.op == '__add__': reduction = (lambda s,t:(s+t)%(2**bits))
            elif t.op == '__sub__': reduction = (lambda s,t:(s-t)%(2**bits))
            elif t.op == '__lshift__': reduction = (lambda s,t:(s<<t)%(2**bits))
            elif t.op == 'LShR': reduction = (lambda s,t:(s>>t)%(2**bits))
            elif t.op == '__rshift__':
                # shift of the signed interpretation of s
                def reduction(s,t):
                    flip = 2**(bits-1)
                    # (v ^ flip) - flip converts unsigned v to its signed interpretation
                    ssigned = (s ^ flip) - flip
                    tsigned = (t ^ flip) - flip
                    # XXX: what are semantics for tsigned outside [0:bits]? also check __lshift__, LShR
                    usigned = ssigned >> tsigned
                    return usigned%(2**bits)
            elif t.op == '__mul__':
                reduction = (lambda s,t:(s*t)%(2**bits))
                W.add('mul')
            elif t.op == '__floordiv__':
                # unsigned division; a zero divisor gives 0 here
                def reduction(s,t):
                    if t == 0: return 0
                    return (s//t)%(2**bits)
                W.add('div')
            elif t.op == '__mod__':
                # unsigned remainder; a zero divisor gives s here
                def reduction(s,t):
                    if t == 0: return s
                    return (s%t)%(2**bits)
                W.add('div')
            elif t.op == 'SDiv':
                def reduction(s,t):
                    if t == 0: return 0
                    flip = 2**(bits-1)
                    ssigned = (s ^ flip) - flip
                    tsigned = (t ^ flip) - flip
                    # sdiv definition in Z3:
                    # - The \c floor of [t1/t2] if \c t2 is different from zero, and [t1*t2 >= 0].
                    # - The \c ceiling of [t1/t2] if \c t2 is different from zero, and [t1*t2 < 0].
                    if ssigned*tsigned >= 0:
                        usigned = ssigned // tsigned
                    else:
                        # Python // rounds down; negating before and after rounds up
                        usigned = -((-ssigned) // tsigned)
                    return usigned%(2**bits)
                W.add('div')
            elif t.op == 'SMod':
                def reduction(s,t):
                    if t == 0: return s
                    flip = 2**(bits-1)
                    ssigned = (s ^ flip) - flip
                    tsigned = (t ^ flip) - flip
                    # srem definition in Z3:
                    # It is defined as t1 - (t1 /s t2) * t2, where /s represents signed division.
                    # The most significant bit (sign) of the result is equal to the most significant bit of \c t1.
                    if ssigned*tsigned >= 0:
                        usigned = ssigned - tsigned*(ssigned // tsigned)
                    else:
                        usigned = ssigned + tsigned*((-ssigned) // tsigned)
                    return usigned%(2**bits)
                W.add('div')
            else:
                print('values: internal error %s, falling back to Z3' % t.op)
                return False
            V[clarikey(t)] = bits,functools.reduce(reduction,(xi[1] for xi in x))
            return True
        if t.op == '__invert__':
            # bitwise complement within the operand's width
            assert len(x) == 1
            bits = x[0][0]
            V[clarikey(t)] = bits,(1<<bits)-1-x[0][1]
            return True
        if t.op == 'Not':
            assert len(x) == 1
            assert all(xi[0] == 1 for xi in x)
            V[clarikey(t)] = 1,1-x[0][1]
            return True
        if t.op in ('And','Or'):
            # arithmetic forms of the boolean operations on 0/1 values
            assert all(xi[0] == 1 for xi in x)
            if t.op == 'And': reduction = (lambda s,t:s*t)
            elif t.op == 'Or': reduction = (lambda s,t:s+t-s*t)
            else:
                print('values: internal error %s, falling back to Z3' % t.op)
                return False
            V[clarikey(t)] = 1,functools.reduce(reduction,(xi[1] for xi in x))
            return True
        if t.op == 'If':
            # note: both branches have already been evaluated above
            assert len(x) == 3
            assert x[0][0] == 1
            if x[0][1]:
                V[clarikey(t)] = x[1]
            else:
                V[clarikey(t)] = x[2]
            return True
        if t.op in ('__le__','ULE','__lt__','ULT','__ge__','UGE','__gt__','UGT','SLE','SLT','SGE','SGT'):
            assert len(x) == 2
            bits = x[0][0]
            assert bits == x[1][0]
            # xor with the top bit maps signed order onto unsigned order
            flip = 2**(bits-1)
            x0,x1 = x[0][1],x[1][1]
            if t.op == '__le__': V[clarikey(t)] = (1,x0<=x1)
            elif t.op == 'ULE': V[clarikey(t)] = (1,x0<=x1)
            elif t.op == '__lt__': V[clarikey(t)] = (1,x0<x1)
            elif t.op == 'ULT': V[clarikey(t)] = (1,x0<x1)
            elif t.op == '__ge__': V[clarikey(t)] = (1,x0>=x1)
            elif t.op == 'UGE': V[clarikey(t)] = (1,x0>=x1)
            elif t.op == '__gt__': V[clarikey(t)] = (1,x0>x1)
            elif t.op == 'UGT': V[clarikey(t)] = (1,x0>x1)
            elif t.op == 'SLE': V[clarikey(t)] = (1,(x0^flip)<=(x1^flip))
            elif t.op == 'SLT': V[clarikey(t)] = (1,(x0^flip)<(x1^flip))
            elif t.op == 'SGE': V[clarikey(t)] = (1,(x0^flip)>=(x1^flip))
            elif t.op == 'SGT': V[clarikey(t)] = (1,(x0^flip)>(x1^flip))
            else:
                print('values: internal error %s, falling back to Z3' % t.op)
                return False
            return True
        # XXX: add support for more
        print('values: unsupported operation %s, falling back to Z3' % t.op)
        return False
    # XXX: also add more validation for all of the above
    for t in terms:
        if not evaluate(t): return None,W
    return V,W
def unroll_print(outputs,unrolled,f):
    """Write the expression DAG behind `unrolled` to `f` as numbered assignments.

    Every distinct term (distinct by clarikey) is written exactly once as a
    line ``vN = ...``, operands before the term that uses them, with N
    counting up from 1.  After all terms, one line
    ``out_<csymbol>_<index> = vN`` is written per output entry, pairing the
    entries described by `outputs`, in order, with the elements of `unrolled`.

    outputs: iterable of (csymbol, symboltype, entries); symboltype is not used.
    unrolled: indexable sequence of expressions, one per output entry.
    f: object with a write(str) method.
    """
    numbering = {}  # clarikey(term) -> N of the line that defines the term

    def walk(term):
        key = clarikey(term)
        if key in numbering:
            return numbering[key]
        op = term.op
        # operands are walked (and therefore printed) while building the
        # right-hand side, so they always get smaller numbers than this term
        if op == 'BoolV':
            rhs = 'bool(%d)' % term.args[0]
        elif op == 'BVV':
            rhs = 'constant(%d,%d)' % (term.size(), term.args[0])
        elif op == 'BVS':
            rhs = '%s' % term.args[0]
        elif op == 'Extract':
            assert len(term.args) == 3
            rhs = 'Extract(v%d,%d,%d)' % (walk(term.args[2]), term.args[0], term.args[1])
        elif op in ('SignExt', 'ZeroExt'):
            assert len(term.args) == 2
            rhs = '%s(v%d,%d)' % (op, walk(term.args[1]), term.args[0])
        else:
            rhs = '%s(%s)' % (op, ','.join(['v%d' % walk(arg) for arg in term.args]))
        index = len(numbering) + 1
        f.write('v%d = %s\n' % (index, rhs))
        numbering[key] = index
        return index

    for term in unrolled:
        walk(term)

    labels = [
        'out_%s_%d' % (csymbol, index)
        for csymbol, symboltype, entries in outputs
        for index in range(entries)
    ]
    for pos, label in enumerate(labels):
        f.write('%s = v%s\n' % (label, walk(unrolled[pos])))
def unroll_inputvars(inputs):
    """Return [(name, symbol), ...] with one claripy symbol per input entry.

    inputs: iterable of (csymbol, symboltype, entries).  Entry number `index`
    of `csymbol` is named in_<csymbol>_<index> and gets typebits[symboltype]
    bits.  The symbols are created with explicit_name=True.
    """
    return [
        (varname, claripy.BVS(varname, typebits[symboltype], explicit_name=True))
        for csymbol, symboltype, entries in inputs
        for varname in ('in_%s_%d' % (csymbol, index) for index in range(entries))
    ]
# XXX: probably better to merge into unroll()
def unroll_worker(binary,inputs,outputs,rust,compiler):
    """Run `binary` symbolically under angr and collect its outputs as expressions.

    One symbol named in_<csymbol>_<index> is created per input entry; the
    concatenation of these symbols is the content of /dev/stdin.  The packets
    the program writes to stdout are read back as the output expressions.

    binary: path of the program to analyze.
    inputs, outputs: iterables of (csymbol, symboltype, entries).
    rust: if true the project is loaded without excluding any simprocedures;
      otherwise memcmp and bcmp are excluded.  Ignored for elfulator compilers.
    compiler: key into compiler_endianness; if it carries the 'elfulator' tag
      the 'elfulator' program is loaded and given `binary` as a simulated file.

    Returns (numexits, ispartition, results).
      Success: the number of exit states, True, and the list of output
      expressions (read from the merged exit state when there are several).
      Failure: -1, False, and a text diagnostic.
    """
    elfulator = compiler_hastag(compiler,'elfulator')

    add_options = {
        angr.options.LAZY_SOLVES,
        angr.options.SYMBOLIC_WRITE_ADDRESSES,
        angr.options.CONSERVATIVE_READ_STRATEGY,
        angr.options.CONSERVATIVE_WRITE_STRATEGY,
    }
    remove_options = {
        angr.options.SIMPLIFY_CONSTRAINTS,
        # angr.options.SIMPLIFY_EXIT_GUARD,
        # angr.options.SIMPLIFY_EXIT_STATE,
        # angr.options.SIMPLIFY_EXIT_TARGET,
        angr.options.SIMPLIFY_EXPRS,
        angr.options.SIMPLIFY_MEMORY_READS,
        angr.options.SIMPLIFY_MEMORY_WRITES,
        angr.options.SIMPLIFY_REGISTER_READS,
        angr.options.SIMPLIFY_REGISTER_WRITES,
        angr.options.SIMPLIFY_RETS,
    }
    if elfulator:
        add_options |= {
            angr.options.ZERO_FILL_UNCONSTRAINED_MEMORY,
            angr.options.ZERO_FILL_UNCONSTRAINED_REGISTERS,
        }
        add_options |= angr.options.unicorn-{angr.options.UNICORN_SYM_REGS_SUPPORT}
    else:
        add_options |= {
            angr.options.SYMBOL_FILL_UNCONSTRAINED_MEMORY,
            angr.options.SYMBOL_FILL_UNCONSTRAINED_REGISTERS,
        }

    # stdin content: every input symbol in order, byte-reversed for
    # little-endian compilers
    stdinvars = []
    for csymbol,symboltype,entries in inputs:
        for i in range(entries):
            varname = 'in_%s_%d'%(csymbol,i)
            variable = claripy.BVS(varname,typebits[symboltype],explicit_name=True)
            if compiler_endianness[compiler] == 'little':
                variable = claripy.Reverse(variable)
            stdinvars.append(variable)
    stdin = angr.SimFile('/dev/stdin',content=claripy.Concat(*stdinvars),has_end=True)

    if elfulator:
        avoidsimprocedures = (
            'clock_gettime',
            '_setjmp',
            '__setjmp',
            '___setjmp',
            'longjmp',
            '_longjmp',
            '__longjmp',
            '___longjmp',
            'sigsetjmp',
            '_sigsetjmp',
            '__sigsetjmp',
            '___sigsetjmp',
            'siglongjmp',
            '_siglongjmp',
            '__siglongjmp',
            '___siglongjmp',
        )
        proj = angr.Project('elfulator',exclude_sim_procedures_list=avoidsimprocedures,auto_load_libs=False,force_load_libs=['libunicorn.so'])

        class posix_memalign(angr.SimProcedure):
            # allocates sim_size bytes, stores the pointer (8 bytes) at
            # sim_ptr, returns 0; sim_alignment is not used
            def run(self,sim_ptr,sim_alignment,sim_size):
                result = self.state.heap._malloc(sim_size)
                self.state.memory.store(sim_ptr,result,size=8)
                return claripy.BVV(0,32)
        proj.hook_symbol('posix_memalign',posix_memalign())

        class sysconf(angr.SimProcedure):
            # answers with the analyzing host's os.sysconf() value
            def run(self,num):
                return claripy.BVV(os.sysconf(num.concrete_value),64)
        proj.hook_symbol('sysconf',sysconf())

        class getpagesize(angr.SimProcedure):
            # answers with the analyzing host's page size
            def run(self):
                return claripy.BVV(resource.getpagesize(),32)
        proj.hook_symbol('getpagesize',getpagesize())

        class strerror(angr.SimProcedure):
            # copies the host's NUL-terminated os.strerror() text into
            # memory obtained from the libc malloc simprocedure
            def run(self,num):
                e = os.strerror(num.concrete_value)+'\0'
                malloc = angr.SIM_PROCEDURES["libc"]["malloc"]
                where = self.inline_call(malloc,len(e)).ret_expr
                self.state.memory.store(where,claripy.BVV(e),size=len(e))
                return where
        proj.hook_symbol('strerror',strerror())

        proj.hook_symbol('mmap64',angr.procedures.posix.mmap.mmap())

        # elfulator receives the path and length of the binary as arguments
        # and finds its contents in the simulated filesystem
        with open(binary,'rb') as f:
            binarycontents = f.read()
        state = proj.factory.full_init_state(add_options=add_options,remove_options=remove_options,args=['elfulator',binary,str(len(binarycontents))],stdin=stdin)
        simfile = angr.SimFile(binary,content=binarycontents)
        simfile.set_state(state)
        state.fs.insert(binary,simfile)
    else:
        if rust:
            proj = angr.Project(binary,auto_load_libs=False)
        else:
            avoidsimprocedures = (
                'memcmp',
                'bcmp',
            )
            proj = angr.Project(binary,exclude_sim_procedures_list=avoidsimprocedures,auto_load_libs=False)
        state = proj.factory.full_init_state(add_options=add_options,remove_options=remove_options,stdin=stdin)

    simgr = proj.factory.simgr(state)
    while True:
        if len(simgr.errored) > 0:
            # every other failure path returns text, and unroll() passes this
            # value on to note() exactly like those, so describe the error
            # records as text instead of returning the record list itself
            return -1,False,''.join(f'{err}\n' for err in simgr.errored)
        if len(simgr.deadended)+len(simgr.active) > maxsplit:
            constraintbuf = io.StringIO()
            constraintnames = ('exited',None,len(simgr.deadended)),('active',None,len(simgr.active))
            # seeded with claripy.true, like the dump for bad exits further
            # down, so that a state without constraints is still a valid And
            unroll_print(constraintnames,[claripy.And(claripy.true,*e.solver.constraints) for e in simgr.deadended+simgr.active],constraintbuf)
            return -1,False,f'saferewrite limiting split to {maxsplit}\n{constraintbuf.getvalue()}'
        if len(simgr.active) == 0:
            break
        if elfulator:
            # XXX: should find a documented interface to do this
            # NOTE(review): this edits the options of the initial `state`, not
            # of the states currently in simgr.active -- confirm that the
            # active states are meant to keep their unicorn options.
            if any(type(e.posix.fd[0]._read_pos) != int for e in simgr.active):
                state.options -= angr.options.unicorn
        simgr.step()

    exits = simgr.deadended
    assert len(exits) > 0

    # every exit must have written exactly one stdout packet per output entry
    expectedpackets = sum(entries for csymbol,symboltype,entries in outputs)
    ok = True
    comment = f'have {len(exits)} exits\n'
    for epos,e in enumerate(exits):
        receivedpackets = len(e.posix.stdout.content)
        if receivedpackets != expectedpackets:
            ok = False
        comment += f'exit {epos} stdout packets {receivedpackets} expecting {expectedpackets} stderr packets {len(e.posix.stderr.content)}\n'
        for packetpos,packet in enumerate(e.posix.stderr.content):
            if not (packet[0].concrete and packet[1].concrete):
                comment += f'stderr symbolic packet {packetpos}\n'
                continue
            x = packet[0].concrete_value
            n = packet[1].concrete_value
            # the low n bytes of x, most significant byte first
            todo = (x & ((1 << (8*n)) - 1)).to_bytes(n,'big')
            for line in todo.splitlines():
                comment += f'stderr packet {packetpos} line: {line}\n'
    if not ok:
        constraintbuf = io.StringIO()
        constraintnames = ('exited',None,len(exits)),
        unroll_print(constraintnames,[claripy.And(claripy.true,*e.solver.constraints) for e in exits],constraintbuf)
        return -1,False,comment+constraintbuf.getvalue()

    # cannot be safe if there are multiple exits
    # for equivalence tests we'll merge exits below
    if len(exits) > 1:
        mergedexit,_,_ = exits[0].merge(*exits[1:],merge_conditions=[e2.solver.constraints for e2 in exits])
    else:
        mergedexit = exits[0]

    results = []
    packetpos = 0
    for csymbol,symboltype,entries in outputs:
        for i in range(entries):
            packet = mergedexit.posix.stdout.content[packetpos]
            packetpos += 1
            # each packet must be exactly as wide as the declared output type
            assert packet[1].concrete_value == typebits[symboltype]//8
            xi = packet[0]
            if compiler_endianness[compiler] == 'little':
                xi = claripy.Reverse(xi)
            results.append(xi)

    ispartition = True
    # XXX: check this
    return len(exits),ispartition,results
def unroll(o,p,i,compiler):
    """Unroll one compiled implementation and check it against the recorded runs.

    Calls unroll_worker() on build/<p>/<i>/<word>/analysis-angr, writes the
    resulting expressions to <builddir>/analysis/unrolled, and then checks for
    every recorded test (inputs from op_x, outputs from opic_y) that the
    expressions reproduce the recorded outputs: first with values(), and with
    a Z3 query when values() does not produce a result.

    The root logger's handlers are replaced by a file handler writing to
    <builddir>/unroll-log.

    Returns (o, p, i, compiler, unrolled).  `unrolled` is the list of output
    expressions, or False if unrolling failed or a recorded test did not
    match; in those cases a warning has already been issued through note().
    """
    word = compiler_word(compiler)
    implementationdir = '%s/%s/%s' % (o,p,i)
    builddir = 'build/%s/%s/%s' % (p,i,word)

    # send all logging of this process to the per-build log file
    logger = logging.getLogger()
    while len(logger.handlers) > 0:
        logger.removeHandler(logger.handlers[-1])
    handler = logging.FileHandler(f'{builddir}/unroll-log')
    handler.setFormatter(angr.misc.loggers.CuteFormatter(should_color=False))
    logger.addHandler(handler)

    inputs,outputs,funargs,funargtypes,crustargs,crustargtypes,rustargs,funname,funret,funrettype,rustfunrettype = op_api[o,p]
    rust = 'Cargo.toml' in os.listdir(implementationdir)

    unrolltime = -cputime()
    numexits,ispartition,unrolled = unroll_worker('%s/analysis-angr'%builddir,inputs,outputs,rust,compiler)
    if numexits < 1:
        # on failure the third result is the diagnostic, not expressions
        note(builddir,'warning-unrollerror',unrolled)
        return o,p,i,compiler,False
    if not ispartition:
        note(builddir,'warning-unrollnotpartition')
        return o,p,i,compiler,False
    if numexits > 1:
        note(builddir,'unsafe-unrollsplit-%d'%numexits)

    with open('%s/analysis/unrolled' % builddir,'w') as f:
        unroll_print(outputs,unrolled,f)

    # the expressions should mention exactly the declared input symbols
    okvars = {vname for vname,v in unroll_inputvars(inputs)}
    usedvars = {v for x in unrolled for v in x.variables}
    if not usedvars.issubset(okvars):
        note(builddir,'warning-unrollmem','\n'.join(sorted(usedvars-okvars))+'\n')
    if not okvars.issubset(usedvars):
        note(builddir,'warning-unusedinputs','\n'.join(sorted(okvars-usedvars))+'\n')

    Wdone = set()  # warnings from values() that were already noted
    for x,y in zip(op_x[o,p],opic_y[o,p,i,compiler]):
        # cpu gave us outputs y given inputs x
        # does this match unrolled?
        replacements = {}
        xpos = 0
        for csymbol,symboltype,entries in inputs:
            for e in range(entries):
                varname = 'in_%s_%d'%(csymbol,e)
                variable = claripy.BVS(varname,typebits[symboltype],explicit_name=True)
                replacements[clarikey(variable)] = claripy.BVV(x[xpos],typebits[symboltype])
                xpos += 1
        assert xpos == len(x)

        mismatch = True
        notestr = 'internal error\n'
        V = None
        try:
            V,W = values(unrolled,replacements)
            for warn in W:
                if warn in Wdone: continue
                Wdone.add(warn)
                note(builddir,'warning-%s'%warn)
        except AssertionError:
            note(builddir,'warning-valuesfailed',traceback.format_exc())
            # proceed with z3 fallback below

        if V is not None:
            # what does not exactly work:
            # mismatch = any(yi != V[clarikey(unrolledi)][1] for (yi,unrolledi) in zip(y,unrolled))
            # because yi can be signed while V is always unsigned
            # so sweep through outputs (which we want to do anyway in case of mismatch)
            mismatch = False
            notestr = '# mismatch between CPU execution of binary and saferewrite execution of unrolled:\n'
            for (vname,v),xi in zip(unroll_inputvars(inputs),x):
                notestr += '%s = %d\n' % (vname,xi)
            pos = 0
            for csymbol,symboltype,entries in outputs:
                for e in range(entries):
                    varname = 'out_%s_%d'%(csymbol,e)
                    executedvalue = y[pos]
                    unrolledvalue = V[clarikey(unrolled[pos])][1]
                    if symboltype in signedtypes:
                        # reinterpret the unsigned value as two's complement
                        if unrolledvalue >= 2**(typebits[symboltype]-1):
                            unrolledvalue -= 2**typebits[symboltype]
                    if executedvalue != unrolledvalue:
                        mismatch = True
                    notestr += 'executed_%s = %s\n' % (varname,executedvalue)
                    notestr += 'unrolled_%s = %s\n' % (varname,unrolledvalue)
                    pos += 1
        else:
            # fall back on Z3 for figuring this out
            s = claripy.Solver(timeout=z3timeout)
            substituted = [unrolledi.replace_dict(replacements) for unrolledi in unrolled]
            differs = claripy.false
            for yi,substitutedi in zip(y,substituted):
                differs = claripy.Or(differs,substitutedi != yi)
            s.add(differs)
            mismatch = s.satisfiable()
            if mismatch:
                notestr = '# hint: this type of mismatch typically reflects reading undefined memory.\n'
                notestr += '# mismatch between CPU execution of binary and Z3 execution of unrolled:\n'
                # The input symbols were substituted away before solving, so
                # the solver does not constrain them: report the tested
                # inputs themselves, and evaluate the substituted expressions
                # rather than the original ones.
                for (vname,v),xi in zip(unroll_inputvars(inputs),x):
                    notestr += '%s = %d\n' % (vname,xi)
                pos = 0
                for csymbol,symboltype,entries in outputs:
                    for e in range(entries):
                        varname = 'out_%s_%d'%(csymbol,e)
                        notestr += 'executed_%s = %s\n' % (varname,y[pos])
                        notestr += 'unrolled_%s = %s\n' % (varname,s.eval(substituted[pos],1)[0])
                        pos += 1

        if mismatch:
            note(builddir,'warning-unrollmismatch',notestr)
            return o,p,i,compiler,False

    unrolltime += cputime()
    notetime(builddir,'unroll',unrolltime)
    return o,p,i,compiler,unrolled
def unroll_wrapper(o,p,i,compiler):
    """Run unroll() for one implementation and store its result as a pickle.

    The result is written to build/<p>/<i>/<word>/unroll-pickle.  The process
    exits with status 111 if unroll() raised an exception (noted as
    warning-unrollexception) or returned False as its result; otherwise the
    function returns normally.
    """
    word = compiler_word(compiler)
    builddir = 'build/%s/%s/%s' % (p,i,word)
    setproctitle.setproctitle(f'saferewrite unroll {builddir}')
    try:
        _,_,_,_,unrolled = unroll(o,p,i,compiler)
        with open(f'{builddir}/unroll-pickle','wb') as f:
            pickle.dump(unrolled,f)
    except Exception:
        # Exception rather than a bare except: KeyboardInterrupt and
        # SystemExit should end the process, not be noted as unroll failures
        note(builddir,'warning-unrollexception',traceback.format_exc())
        sys.exit(111)
    if unrolled is False: # warning already issued by unroll()
        sys.exit(111)
def wanttounroll():
    """Yield (o, p, i, compiler) for each compiled implementation present in opic_y.

    Primitives are visited in the order of `primitives`, implementations in
    the order of op_compiled[o,p].
    """
    for o,p in primitives:
        yield from (
            (o,p,i,compiler)
            for i,compiler in op_compiled[o,p]
            if (o,p,i,compiler) in opic_y
        )
# seems more robust than multiprocessing for spawning angr workers
def starmap_status(fun,todo):
    """Call fun(*entry) for each entry of `todo`, each in its own forked process.

    At most os_threads children run at a time.  A child exits with status 0
    if `fun` returns.

    fun: callable invoked in the child process.
    todo: iterable of argument tuples; the tuples are used as dictionary
      keys, and None marks the end of the iteration, so no entry may be None.

    Returns a dictionary mapping each entry to the wait status that os.wait()
    reported for its child.
    """
    result = {}
    workers = {}  # pid of a running child -> its todo entry
    todo = iter(todo)
    while True:
        if todo is None and not workers:
            return result
        if todo is not None and len(workers) < os_threads:
            todoentry = next(todo,None)
            if todoentry is None:
                # nothing left to start; keep reaping what is still running
                todo = None
                if not workers:
                    return result
            else:
                # The child inherits a copy of any unflushed stdio buffers and
                # writes them out when it leaves through sys.exit(), so flush
                # first to keep parent output from being emitted once more
                # by every child.
                sys.stdout.flush()
                sys.stderr.flush()
                pid = os.fork()
                if pid == 0:
                    fun(*todoentry)
                    sys.exit(0)
                workers[pid] = todoentry
                continue
        pid,status = os.wait()
        # os.wait() can also report children that were not started here
        if pid in workers:
            result[workers[pid]] = status
            del workers[pid]
# unroll every wanted implementation in its own worker process;
# opic_status maps (o,p,i,compiler) to the wait status of that worker
opic_status = starmap_status(unroll_wrapper,wanttounroll())
for o,p,i,compiler in wanttounroll():
    status = opic_status[o,p,i,compiler]
    # 0: worker succeeded; 256*111: worker called sys.exit(111),
    # which unroll_wrapper does after a warning has been noted
    if status in (0,256*111):
        continue
    word = compiler_word(compiler)
    builddir = f'build/{p}/{i}/{word}'
    note(builddir,'warning-unrollexitcode',f'exit status {status}\n')
print('===== equiv (can take tons of time)')
def equiv(o,p,i,compiler,source,sourcecompiler):
    """Compare implementation (i, compiler) of (o, p) with (source, sourcecompiler).

    First the recorded outputs of the two implementations are compared test by
    test; every difference is noted as unsafe-randomtest-<pos>-differentfrom-...
    If a test differed and satvalidation1 is false, nothing further is done.

    Otherwise the unroll-pickle files of both implementations are loaded and a
    solver is asked whether any pair of corresponding outputs can differ.  The
    outcome is noted as unsafe-differentfrom-<source>-<sourceword> (with an
    example) or equals-<source>-<sourceword>, followed by the time spent.  If
    the solver raises ClaripyZ3Error, warning-z3failed is noted instead.
    """
    word = compiler_word(compiler)
    sourceword = compiler_word(sourcecompiler)
    builddir = 'build/%s/%s/%s' % (p,i,word)
    inputs,outputs,*_ = op_api[o,p]
    setproctitle.setproctitle(f'saferewrite equiv {builddir} build/{p}/{source}/{sourceword}')

    randomtestspass = True
    recorded = zip(op_x[o,p],opic_y[o,p,i,compiler],opic_y[o,p,source,sourcecompiler])
    for pos,(x,y,z) in enumerate(recorded):
        if y != z:
            report = input_example_str(inputs,x)
            report += output_example_str(outputs,y)
            report += output_example_str(outputs,z)
            note(builddir,'unsafe-randomtest-%d-differentfrom-%s-%s' % (pos,source,sourceword),report)
            randomtestspass = False
            # continue through random tests to get statistics

    if not (randomtestspass or satvalidation1):
        return

    equivtime = -cputime()

    with open(f'build/{p}/{source}/{sourceword}/unroll-pickle','rb') as f:
        sourceexprs = pickle.load(f)
    with open(f'build/{p}/{i}/{word}/unroll-pickle','rb') as f:
        targetexprs = pickle.load(f)
    assert len(sourceexprs) == len(targetexprs)

    # XXX: allow other equivalence-testing techniques
    s = claripy.Solver(timeout=z3timeout)
    different = claripy.false
    for sourceexpr,targetexpr in zip(sourceexprs,targetexprs):
        different = claripy.Or(different,sourceexpr != targetexpr)
    s.add(different)

    try:
        mismatch = s.satisfiable()
    except claripy.errors.ClaripyZ3Error:
        # avoid crashing on the sort of bug fixed in https://github.com/angr/angr/pull/2887
        note(builddir,'warning-z3failed',traceback.format_exc())
        return

    if not mismatch:
        note(builddir,'equals-%s-%s' % (source,sourceword))
    else:
        # angr documentation says:
        # "If you don't add any constraints between two queries, the results will be consistent with each other."
        lines = []
        for vname,v in unroll_inputvars(inputs):
            lines.append('%s = %s\n' % (vname,s.eval(v,1)[0]))
        unrolledpos = 0
        for csymbol,symboltype,entries in outputs:
            for index in range(entries):
                varname = 'out_%s_%d'%(csymbol,index)
                lines.append('source_%s = %s\n' % (varname,s.eval(sourceexprs[unrolledpos],1)[0]))
                lines.append('target_%s = %s\n' % (varname,s.eval(targetexprs[unrolledpos],1)[0]))
                unrolledpos += 1
        note(builddir,'unsafe-differentfrom-%s-%s' % (source,sourceword),''.join(lines))

    equivtime += cputime()
    notetime(builddir,'equiv',equivtime)
def wanttoequiv():
    """Yield (o, p, i, compiler, source, sourcecompiler) for each equivalence test.

    Only implementations whose unroll worker finished with status 0 take
    part.  Per primitive they are ordered with 'ref' first, then by length of
    the implementation name, then by name, then by compiler_pos(compiler).
    An implementation other than 'ref' is paired with 'ref' built by the same
    compiler if that is available; every other implementation is paired with
    the first entry of the ordering.
    """
    for o,p in primitives:
        # XXX: allow each implementation to choose its source (if available)
        # .get() instead of indexing: wanttounroll() skips implementations
        # that are missing from opic_y, so those have no status entry at all
        ranked = sorted(
            (i!='ref',len(i),i,compiler_pos(compiler),compiler)
            for i,compiler in op_compiled[o,p]
            if opic_status.get((o,p,i,compiler)) == 0
        )
        iclist = [(i,compiler) for _,_,i,_,compiler in ranked]
        iclistpos = {key:pos for pos,key in enumerate(iclist)}
        for i,compiler in iclist:
            key = 'ref',compiler
            if i != 'ref' and key in iclistpos and iclistpos[key] < iclistpos[i,compiler]:
                yield (o,p,i,compiler,*key)
            else:
                # for the first entry of iclist this is a self-test
                yield (o,p,i,compiler,*iclist[0])
# run every equivalence test in its own worker process;
# the returned wait statuses are not used
starmap_status(equiv,wanttoequiv())