-rwxr-xr-x 3271 nttcompiler-20220411/scripts/unroll
#!/usr/bin/env python3
import angr
import sys
# Symbolic execution can build very deep expression trees; raise Python's
# default recursion limit (presumably so recursive traversal of those trees,
# here or inside angr/claripy, does not fail).
sys.setrecursionlimit(100000)

# Command-line element type name -> bits per array entry.
typebits = {'int%d' % width: width for width in (8, 16, 32, 64)}

# Command line: the binary to analyse, then array specs and flags.
binary = sys.argv[1]
args = sys.argv[2:]

wantexpr = 0  # becomes 1 when the 'expr' flag is seen
inputs, outputs = [], []  # one (csymbol, bitsperentry, entries) per array
io = inputs  # spec list currently being filled; '--' switches it to outputs
# Consume the remaining command-line words.  An array is described by three
# words, TYPE SYMBOL COUNT, where TYPE is a key of typebits.  Specs are added
# to the list `io` points at (inputs) until a '--' separator switches it to
# outputs.  The word 'expr' anywhere sets wantexpr.
_usage = ('usage: unroll binary [expr] type symbol count ... '
          '[-- type symbol count ...]   (type: %s)' % ' '.join(typebits))
while args:
    if args[0] == '--':
        args = args[1:]
        io = outputs
        continue
    if args[0] == 'expr':
        wantexpr = 1
        args = args[1:]
        continue
    # Validate explicitly rather than with assert / an unhandled KeyError, so
    # a malformed command line yields a readable message even under -O.
    if len(args) < 3:
        sys.exit('unroll: incomplete array spec: %s\n%s'
                 % (' '.join(args), _usage))
    typename, csymbol, count = args[:3]
    if typename not in typebits:
        sys.exit('unroll: unknown type %s\n%s' % (typename, _usage))
    try:
        entries = int(count)
    except ValueError:
        sys.exit('unroll: bad entry count %s\n%s' % (count, _usage))
    bitsperentry = typebits[typename]
    args = args[3:]
    io.append((csymbol, bitsperentry, entries))
# Load the binary and create a full_init_state() entry state for it.
proj = angr.Project(binary)
state = proj.factory.full_init_state()

# Options switched on: lazy solving, and filling unconstrained memory and
# registers with symbols.
state.options |= {
    angr.options.LAZY_SOLVES,
    angr.options.SYMBOL_FILL_UNCONSTRAINED_MEMORY,
    angr.options.SYMBOL_FILL_UNCONSTRAINED_REGISTERS,
}

# Options switched off: every expression / register / memory simplification
# flag (presumably so recorded expressions keep their unsimplified shape).
state.options -= {
    angr.options.SIMPLIFY_EXPRS,
    angr.options.SIMPLIFY_REGISTER_WRITES,
    angr.options.SIMPLIFY_MEMORY_WRITES,
    angr.options.SIMPLIFY_REGISTER_READS,
    angr.options.SIMPLIFY_MEMORY_READS,
}
# Make every input array symbolic.  Entry i of array SYM becomes a fresh
# bitsperentry-bit symbol created with the name in_SYM_i.  It is stored at
# SYM + i*bytesperentry in the target's memory byte order.  Storing with an
# explicit byte count, rather than through C type names such as 'long' whose
# width differs between targets, keeps every entry exactly bitsperentry bits
# wide on any architecture.
for csymbol, bitsperentry, entries in inputs:
    symbol = proj.loader.find_symbol(csymbol)
    if symbol is None:
        sys.exit('unroll: input symbol %s not found in %s' % (csymbol, binary))
    xaddr = symbol.rebased_addr
    bytesperentry = bitsperentry // 8
    for i in range(entries):
        varname = 'in_%s_%d' % (csymbol, i)
        variable = state.solver.BVS(varname, bitsperentry)
        state.memory.store(xaddr + bytesperentry * i, variable,
                           endness=state.arch.memory_endness)
# Symbolically execute from the entry state until no active path remains.
simgr = proj.factory.simgr(state)
simgr.run()

# Check for errored paths first: an errored path may leave no dead-ended path
# behind, and the path-count check would then fail without saying why.
# Explicit checks (not assert) so they still run under python -O.
if simgr.errored:
    sys.exit('unroll: symbolic execution failed: %s' % simgr.errored[0])

exits = simgr.deadended
# Only exits[0] is read afterwards, so exactly one terminating path is needed.
if len(exits) != 1:
    sys.exit('unroll: expected exactly one terminating path, found %d'
             % len(exits))
walked = {}   # expression node -> number N of the vN variable emitted for it
walknext = 0  # number of the most recently emitted vN variable


def _walk_operands(t):
    """Return the sub-expressions of ``t`` that must be emitted before ``t``."""
    if t.op in ('BVV', 'BVS'):
        return ()
    if t.op == 'Extract':
        # args = (high bit, low bit, operand)
        assert len(t.args) == 3
        return (t.args[2],)
    if t.op in ('SignExt', 'ZeroExt'):
        # args = (number of extension bits, operand)
        assert len(t.args) == 2
        return (t.args[1],)
    return tuple(t.args)


def _walk_line(t, n):
    """Format the line assigning node ``t`` to vN; its operands must be in ``walked``."""
    if t.op == 'BVV':
        return 'v%d = constant(%d,%d)' % (n, t.size(), t.args[0])
    if t.op == 'BVS':
        return 'v%d = %s' % (n, t.args[0])
    if t.op == 'Extract':
        return 'v%d = Extract(v%d,%d,%d)' % (n, walked[t.args[2]], t.args[0], t.args[1])
    if t.op in ('SignExt', 'ZeroExt'):
        return 'v%d = %s(v%d,%d)' % (n, t.op, walked[t.args[1]], t.args[0])
    operands = ','.join('v%d' % walked[a] for a in t.args)
    return 'v%d = %s(%s)' % (n, t.op, operands)


def walk(t):
    """Print straight-line assignments computing expression ``t``; return its vN number.

    Nodes are printed in post-order (operands left to right, before the node
    that uses them), each node at most once.  The memo ``walked`` and the
    counter ``walknext`` persist across calls, so repeated calls share one
    numbering and never re-emit a node.

    An explicit stack is used instead of recursion, so the depth of the
    expression is not bounded by Python's recursion limit or the C stack.
    """
    global walknext
    if t in walked:
        return walked[t]
    stack = [(t, False)]
    while stack:
        node, operands_done = stack.pop()
        if node in walked:
            # Shared sub-expression already emitted through another path.
            continue
        if not operands_done:
            # Revisit this node once all its operands have been emitted;
            # push the operands reversed so the leftmost one is emitted first.
            stack.append((node, True))
            stack.extend((child, False) for child in reversed(_walk_operands(node)))
            continue
        walknext += 1
        print(_walk_line(node, walknext))
        walked[node] = walknext
    return walked[t]
# Emit the outputs from the single terminating path.  Pass 0 prints, for each
# output entry, the straight-line program walk() produces for it followed by
# out_SYM_I = vN.  Pass 1 runs only when 'expr' was given; it prints each
# entry's complete expression as a comment.
final = exits[0]
endness = final.arch.memory_endness

# Resolve every output symbol once, up front, so a missing symbol is reported
# clearly before any output is printed.
outspecs = []
for csymbol, bitsperentry, entries in outputs:
    symbol = proj.loader.find_symbol(csymbol)
    if symbol is None:
        sys.exit('unroll: output symbol %s not found in %s' % (csymbol, binary))
    outspecs.append((csymbol, bitsperentry // 8, entries, symbol.rebased_addr))

for expr in range(wantexpr + 1):
    if expr:
        print('')
        print('# expressions (can explode exponentially):')
    for csymbol, bytesperentry, entries, xaddr in outspecs:
        for i in range(entries):
            varname = 'out_%s_%d' % (csymbol, i)
            # Load exactly bytesperentry bytes in the target's byte order, so
            # the entry width does not depend on the target's C data model.
            x = final.memory.load(xaddr + bytesperentry * i, bytesperentry,
                                  endness=endness)
            if expr:
                print('# %s = %s' % (varname, x))
            else:
                print('%s = v%s' % (varname, walk(x)))
sys.stdout.flush()