-rwxr-xr-x 15821 nttcompiler-20220411/scripts/opt2range
#!/usr/bin/env python3
import sys
import math
from pyparsing import StringEnd,Literal,Word,ZeroOrMore,OneOrMore,Optional,Forward,alphas,nums
# Command line: opt2range MODULUS INPUTLOW INPUTHIGH
#   modulus:   integer used in the "% modulus == 0" tests of the rem* rules
#   inputlow, inputhigh: range assigned to every input variable
modulus, inputlow, inputhigh = int(sys.argv[1]), int(sys.argv[2]), int(sys.argv[3])
def group(s):
    """Return a parse action that tags a multi-token match with `s`.

    The returned callable converts its tokens to a list.  A single token is
    returned unchanged as a one-element list; otherwise the tokens are
    wrapped into one nested list [s, token0, token1, ...].
    """
    def t(x):
        tokens = list(x)
        return tokens if len(tokens) == 1 else [[s] + tokens]
    return t
# Punctuation is suppressed so that it never shows up in the parse results.
lparen = Literal('(').suppress()
rparen = Literal(')').suppress()
comma = Literal(',').suppress()
equal = Literal('=').suppress()
number = Word(nums)
name = Word(alphas,alphas+nums+"_")

def _rule(tag,*pieces):
    """Parser for `name = tag(pieces...)`, reported as [tag, name, args...]."""
    expr = name + equal + Literal(tag).suppress() + lparen
    for piece in pieces:
        expr = expr + piece
    return (expr + rparen).setParseAction(group(tag))

# Alternatives are tried in the order they are added; the bare copy
# `name = name` comes last.
assignment = _rule('constant',number,comma,number)
for binary in ['__sub__','__rshift__','__lshift__','mulhi16','mulhrs16']:
    assignment |= _rule(binary,name,comma,name)
assignment |= _rule('Extract',name,comma,number,comma,number)
assignment |= _rule('SignExt',name,comma,number)
assignment |= _rule('ZeroExt',name,comma,number)
for manyary in ['Concat','__or__','__and__','__add__','__mul__']:
    assignment |= _rule(manyary,name,ZeroOrMore(comma + name))
assignment |= (name + equal + name).setParseAction(group('copy'))

# A program is any number of assignments followed by end of input.
assignments = ZeroOrMore(assignment) + StringEnd()
# Read the whole program from stdin, skipping lines that start with '#'.
# Every kept line gets an additional '\n' appended before parsing.
program = ''.join(line+'\n' for line in sys.stdin if not line.startswith('#'))
# Parse into a plain list of [operation, target, operands...] entries.
program = list(assignments.parseString(program))
# most recently allocated value number (0 means none allocated yet)
nextvalue = 0

# indexed by variable name: value number bound to that name
value = dict()

# indexed by value number:
operation = dict()  # [opname, extra arguments...]
parents = dict()    # list of operand value numbers
bits = dict()       # width in bits
low = dict()        # lower bound of the signed range
high = dict()       # upper bound of the signed range

# value numbers that have been given a rem* formula
formula = set()
def input(v):
    """Allocate a fresh value number for the input variable named `v`.

    The width is the integer after the last underscore of the name, and the
    range is [inputlow, inputhigh].  Note that this shadows the builtin
    `input` within this script.
    """
    global nextvalue
    width_text = v.split('_')[-1]
    nextvalue += 1
    n = nextvalue
    value[v] = n
    operation[n] = ['input',v]
    parents[n] = []
    bits[n] = int(width_text)
    low[n] = inputlow
    high[n] = inputhigh
def exact(v,L,H):
    """Return whether [L,H] fits in the signed range of value v's width.

    Asserts L <= H.
    """
    b = bits[v]
    assert L <= H
    bound = 2**(b-1)
    return -bound <= L and H < bound
def narrowrange(v,L,H):
    """Intersect the range of v with [L,H], but only if [L,H] fits v's width."""
    if not exact(v,L,H):
        return
    if L > low[v]:
        low[v] = L
    if H < high[v]:
        high[v] = H
# (x, mask) -> value number of x & mask, for nonzero masks of the form 2**k-1;
# filled in by narrowrange_and and consulted by narrowrange_rshift.
masked = dict()
def narrowrange_rshift(v):
    """Range rule for v = x >> d when the shift distance d is a known constant.

    A formula is emitted only if the range is exact, both operands already
    have formulas, and the low bits x & (2**dist-1) were recorded in `masked`.
    """
    if operation[v][0] != '__rshift__': return
    x,d = parents[v]
    dist = low[d]
    if high[d] != dist: return
    scale = 2**dist
    L = low[x]//scale
    H = high[x]//scale
    narrowrange(v,L,H)
    if not exact(v,L,H): return
    if not all(p in formula for p in parents[v]): return
    key = (x,scale-1)
    if key not in masked: return
    print('rem%d = (rem%d-rem%d)/2**%d' % (v,x,masked[key],dist))
    formula.add(v)
def narrowrange_lshift(v):
    """Range rule for v = x << d when the shift distance d is a known constant.

    A formula is emitted if the scaled range is exact and both operands
    already have formulas.
    """
    if operation[v][0] != '__lshift__': return
    x,d = parents[v]
    dist = low[d]
    if high[d] != dist: return
    scale = 2**dist
    L = low[x]*scale
    H = high[x]*scale
    narrowrange(v,L,H)
    if not exact(v,L,H): return
    if not all(p in formula for p in parents[v]): return
    print('rem%d = rem%s*2**%d' % (v,x,dist))
    formula.add(v)
def narrowrange_and(v):
    """Range rule for v = x & mask where mask is a constant of the form 2**k-1.

    Only the second operand is considered as the mask.  The result lies in
    [0,mask]; for a nonzero mask the pair (x,mask) is recorded in `masked`
    and a `rem = quo` line is printed.
    """
    if operation[v][0] != '__and__': return
    # The grammar accepts __and__ with any number of operands; this rule
    # only understands the two-operand form.
    if len(parents[v]) != 2: return
    x,y = parents[v]
    if low[y] != high[y]: return
    mask = low[y]
    if mask < 0: return
    # mask & (mask+1) is zero exactly when mask is 2**k-1
    if mask & (mask+1): return
    narrowrange(v,0,mask)
    if mask != 0:
        masked[x,mask] = v
        print('rem%d = quo%d' % (v,v))
        formula.add(v)
def narrowrange_and_cond(v):
    """Range rule for v = x & c where c has range [-1,0].

    Either operand may be the [-1,0] one.  The result range is the range of
    the other operand extended to include 0.
    """
    if operation[v][0] != '__and__': return
    # The grammar accepts __and__ with any number of operands; this rule
    # only understands the two-operand form.
    if len(parents[v]) != 2: return
    x,y = parents[v]
    if (low[y],high[y]) != (-1,0):
        y,x = parents[v]
        if (low[y],high[y]) != (-1,0): return
    narrowrange(v,min(low[x],0),max(high[x],0))
def narrowrange_or(v):
    """Range rule for __or__ when all operands, or all but one, are exactly 0.

    With every operand zero the result is 0; with exactly one operand not
    known to be zero the result copies that operand's range and formula.
    """
    if operation[v][0] != '__or__': return
    nonzero = [x for x in parents[v] if low[x] != 0 or high[x] != 0]
    if not nonzero:
        narrowrange(v,0,0)
    if len(nonzero) == 1:
        x = nonzero[0]
        narrowrange(v,low[x],high[x])
        if x in formula:
            print('rem%d = rem%d' % (v,x))
            formula.add(v)
# range of x+((x>>(b-1))&q) has two cases:
# x >= 0: range of x narrowed to [0,...]
# x < 0: range of x narrowed to [...,-1], plus q
def narrowrange_add_reduce(v):
    """Range rule for the conditional add v = x + ((x >> (b-1)) & q).

    The operands of the add and of the and may appear in either order.  A
    formula `rem v = rem x` is emitted when the range is exact, x has a
    formula, and q is a constant divisible by `modulus`.
    """
    if operation[v][0] != '__add__': return
    if len(parents[v]) != 2: return
    x,y = parents[v]
    if operation[y][0] != '__and__':
        y,x = parents[v]
        if operation[y][0] != '__and__': return
    # __and__ is variadic in the grammar; only the two-operand form matches.
    if len(parents[y]) != 2: return
    s,q = parents[y]
    if operation[s][0] != '__rshift__':
        q,s = parents[y]
        if operation[s][0] != '__rshift__': return
    if parents[s][0] != x: return
    d = parents[s][1]
    b = bits[v]
    # the shift distance must be exactly b-1
    if low[d] != b-1: return
    if high[d] != b-1: return
    if low[x] >= 0:
        L,H = low[x],high[x]
    elif high[x] < 0:
        L,H = low[x]+low[q],high[x]+high[q]
    else:
        # x may have either sign: union of the two cases above
        L1,H1 = 0,high[x]
        L2,H2 = low[x]+low[q],-1+high[q]
        L,H = min(L1,L2),max(H1,H2)
    narrowrange(v,L,H)
    if exact(v,L,H):
        if x in formula:
            if low[q] == high[q]:
                if low[q] % modulus == 0:
                    print('rem%d = rem%d' % (v,x))
                    formula.add(v)
def narrowrange_add(v):
    """Range rule for __add__: sum the operand ranges.

    If the summed range is exact and every operand has a formula, emit the
    sum of the operand formulas.
    """
    if operation[v][0] != '__add__': return
    L = sum(low[x] for x in parents[v])
    H = sum(high[x] for x in parents[v])
    narrowrange(v,L,H)
    if not exact(v,L,H): return
    if not all(p in formula for p in parents[v]): return
    terms = ['rem%d'%p for p in parents[v]]
    print('rem%d = %s' % (v,'+'.join(terms)))
    formula.add(v)
def narrowrange_constant(v):
    """Set the range of a constant to its single signed value.

    A stored value of 2**(b-1) or more is read as the two's-complement
    negative number value - 2**b.
    """
    if operation[v][0] != 'constant': return
    b = bits[v]
    c = operation[v][2]
    if c >= 2**(b-1):
        c -= 2**b
    low[v] = high[v] = c
def narrowrange_mul(v):
    """Range rule for __mul__: interval product of all operand ranges.

    If the product range is exact and every operand has a formula, emit the
    product of the operand formulas.
    """
    if operation[v][0] != '__mul__': return
    L,H = 1,1
    for x in parents[v]:
        corners = (L*low[x],L*high[x],H*low[x],H*high[x])
        L,H = min(corners),max(corners)
    narrowrange(v,L,H)
    if not exact(v,L,H): return
    if not all(p in formula for p in parents[v]): return
    factors = ['rem%d'%p for p in parents[v]]
    print('rem%d = %s' % (v,'*'.join(factors)))
    formula.add(v)
def narrowrange_mulhrs16(v):
    """Range rule for mulhrs16: interval product, then ((p//2**14)+1)//2."""
    if operation[v][0] != 'mulhrs16': return
    L,H = 1,1
    for x in parents[v]:
        corners = (L*low[x],L*high[x],H*low[x],H*high[x])
        L,H = min(corners),max(corners)
    # apply the same rounding shift to both ends of the product range
    L,H = [((p//2**14)+1)//2 for p in (L,H)]
    narrowrange(v,L,H)
def narrowrange_mulhi16(v):
    """Range rule for mulhi16: interval product, floor-divided by 2**16."""
    if operation[v][0] != 'mulhi16': return
    L,H = 1,1
    for x in parents[v]:
        corners = (L*low[x],L*high[x],H*low[x],H*high[x])
        L,H = min(corners),max(corners)
    # sanity check kept from the original: floor division rounds down
    assert L >= (L//2**16)*2**16
    L //= 2**16
    H //= 2**16
    narrowrange(v,L,H)
def narrowrange_sub(v):
    """Range rule for __sub__: [low(a)-high(b), high(a)-low(b)].

    If that range is exact and both operands have formulas, emit the
    difference of the operand formulas.
    """
    if operation[v][0] != '__sub__': return
    a = parents[v][0]
    b = parents[v][1]
    L = low[a]-high[b]
    H = high[a]-low[b]
    narrowrange(v,L,H)
    if not exact(v,L,H): return
    if not all(p in formula for p in parents[v]): return
    terms = ['rem%d'%p for p in parents[v]]
    print('rem%d = %s' % (v,'-'.join(terms)))
    formula.add(v)
# -----
# assumptions: q odd, 0<q<2^15, r subtraction below is exact
# and qinv is reciprocal of q mod 2^16
# computation: r = mulhi(x,y)-mulhi(mullo(x,y,qinv),q)
# or: r = mulhi(x,y)-mulhi(mullo(x,yqinv),q)
# conclusion: xy = 2^16 r + dq where d is the mullo(...) value
# if flipped: the subtraction is the other way around, so v = -r
def narrowrange_sub_hilohi(v,flipped):
    """Range and formula rule for the mulhi/mullo/mulhi reduction above.

    v is mulhi16(x,y) - mulhi16(d,q), or the reverse when `flipped` is true.
    The range is derived from the two mulhi16 ranges.  A formula is emitted
    only when d is recognised as x*y*qinv or x*(y*qinv), the range is exact,
    q is divisible by `modulus`, and x and y both have formulas.
    """
    if operation[v][0] != '__sub__': return
    if flipped:
        e,b = parents[v]
    else:
        b,e = parents[v]
    if operation[b][0] != 'mulhi16': return
    if operation[e][0] != 'mulhi16': return
    x,y = parents[b]
    d,q = parents[e]
    if operation[q][0] != 'constant':
        q,d = parents[e]
        if operation[q][0] != 'constant': return
    q = operation[q][2]
    if not(q & 1): return
    if q < 0: return
    if q > 2**15: return
    corners = (low[x]*low[y],low[x]*high[y],high[x]*low[y],high[x]*high[y])
    xylow = min(corners)
    xyhigh = max(corners)
    # x*y is between xylow and xyhigh
    hixylow = xylow//2**16
    hixyhigh = xyhigh//2**16
    # mulhi(x,y) is between hixylow and hixyhigh
    loqlow = -2**15*q
    loqhigh = (2**15-1)*q
    # mullo(...)*q is between loqlow and loqhigh
    hiloqlow = loqlow//2**16
    hiloqhigh = loqhigh//2**16
    # mulhi(mullo(...),q) is between hiloqlow and hiloqhigh
    rlow = hixylow-hiloqhigh
    rhigh = hixyhigh-hiloqlow
    # r is between rlow and rhigh
    if flipped:
        # v = mulhi(mullo(...),q) - mulhi(x,y) = -r, so negate the interval
        rlow,rhigh = -rhigh,-rlow
    narrowrange(v,rlow,rhigh)
    # the conclusion above needs d to really be a (low) product
    if operation[d][0] != '__mul__': return
    if len(parents[d]) not in (2,3): return
    if len(parents[d]) == 2:
        # d = x * yqinv with yqinv*q = y mod 2^16
        d1,d2 = parents[d]
        if d1 != x:
            d2,d1 = parents[d]
            if d1 != x:
                return
        if operation[d2][:2] != ['constant',16]: return
        yqinv = operation[d2][2]
        # y must be a constant before its value can be read
        if operation[y][0] != 'constant': return
        if (yqinv*q-operation[y][2])%(2**16) != 0: return
    if len(parents[d]) == 3:
        # d = x * y * qinv in any operand order, with qinv*q = 1 mod 2^16
        d1,d2,d3 = parents[d]
        if d1 != x:
            d1,d2,d3 = d2,d3,d1
            if d1 != x:
                d1,d2,d3 = d2,d3,d1
                if d1 != x: return
        if d2 != y:
            d2,d3 = d3,d2
            if d2 != y: return
        if operation[d3][:2] != ['constant',16]: return
        qinv = operation[d3][2]
        if (qinv*q)%(2**16) != 1: return
    if exact(v,rlow,rhigh):
        if q % modulus == 0:
            if x not in formula:
                print('# %s not in formula' % x)
            if y not in formula:
                print('# %s not in formula' % y)
            if x in formula and y in formula:
                if flipped:
                    print('rem%d = (((-1)*rem%d)*rem%d)/2**16' % (v,x,y))
                else:
                    print('rem%d = (rem%d*rem%d)/2**16' % (v,x,y))
                formula.add(v)
# XXX: should also do 2-input version with y*qinv precomputed
def sub16(a,b):
    """Return a-b reduced to the signed 16-bit range [-2**15, 2**15)."""
    diff = (a-b) % 2**16
    return diff - 2**16 if diff >= 2**15 else diff
def mul16(a,b):
    """Return a*b reduced to the signed 16-bit range [-2**15, 2**15)."""
    prod = (a*b) % 2**16
    return prod - 2**16 if prod >= 2**15 else prod
def mulhi16(a,b):
    """Return the high 16 bits of the signed 16x16-bit product of a and b.

    Each operand is first reduced to its signed 16-bit value, so a constant
    written in unsigned form (e.g. 65535 for -1) is multiplied as the
    negative number it encodes.  For operands already in [-2**15, 2**15)
    the result is unchanged.
    """
    a = (a + 2**15) % 2**16 - 2**15
    b = (b + 2**15) % 2**16 - 2**15
    # With both operands signed 16-bit, the quotient is within
    # [-2**14, 2**14] and needs no further wrapping.
    return (a*b) // 2**16
def mulhrs16(a,b):
    """Return the rounded high product ((a*b >> 14) + 1) >> 1 as signed 16-bit.

    Each operand is first reduced to its signed 16-bit value, so a constant
    written in unsigned form (e.g. 65535 for -1) is multiplied as the
    negative number it encodes.  For operands already in [-2**15, 2**15)
    the result is unchanged.
    """
    a = (a + 2**15) % 2**16 - 2**15
    b = (b + 2**15) % 2**16 - 2**15
    c = a*b
    c //= 2**14
    c += 1
    c //= 2
    # only (-2**15)*(-2**15) overflows; it wraps to -2**15
    c %= 2**16
    if c >= 2**15: c -= 2**16
    return c
# Memo tables for the exhaustive 16-bit sweeps done by narrowrange_reduce2
# and narrowrange_reduce3; each entry holds (L,H,modular).
reduce2_cache = dict()
reduce3_cache = dict()
# x-mulhrs16(x,B)*C
# or mulhrs16(x,B)*C-x if flipped
def narrowrange_reduce2(v,flipped):
    """Range and formula rule for the two-constant reduction above.

    B and C must be 16-bit constants.  The output range is found by trying
    every signed 16-bit x; `modular` records whether the result always agrees
    with x (or with -x when flipped) modulo `modulus`.  Results are cached
    per (B,C,flipped), because the flipped range is the negation of the
    unflipped one.
    """
    if operation[v][0] != '__sub__': return
    if flipped:
        y,x = parents[v]
    else:
        x,y = parents[v]
    if operation[y][0] != '__mul__': return
    # __mul__ is variadic in the grammar; only the two-operand form matches
    if len(parents[y]) != 2: return
    z,C = parents[y]
    if operation[z][0] != 'mulhrs16':
        C,z = parents[y]
        if operation[z][0] != 'mulhrs16':
            return
    if operation[C][:2] != ['constant',16]: return
    x2,B = parents[z]
    if x2 != x:
        B,x2 = parents[z]
        if x2 != x:
            return
    if operation[B][:2] != ['constant',16]: return
    B = operation[B][2]
    C = operation[C][2]
    key = (B,C,flipped)
    if key not in reduce2_cache:
        L = None
        H = None
        modular = True
        # XXX: can use low[x],high[x] with more work on caching
        for i in range(-2**15,2**15):
            if flipped:
                j = sub16(mul16(mulhrs16(i,B),C),i)
                if (j+i)%modulus: modular = False
            else:
                j = sub16(i,mul16(mulhrs16(i,B),C))
                if (j-i)%modulus: modular = False
            if L is None or j < L: L = j
            if H is None or j > H: H = j
        reduce2_cache[key] = L,H,modular
    L,H,modular = reduce2_cache[key]
    narrowrange(v,L,H)
    if exact(v,L,H):
        if modular:
            if x not in formula:
                print('# %s not in formula' % x)
            if x in formula:
                if flipped:
                    print('rem%d = (-1)*rem%d' % (v,x))
                else:
                    print('rem%d = rem%d' % (v,x))
                formula.add(v)
# x-mulhrs16(mulhi16(x,A),B)*C
# or mulhrs16(mulhi16(x,A),B)*C-x if flipped
def narrowrange_reduce3(v,flipped):
    """Range and formula rule for the three-constant reduction above.

    A, B and C must be 16-bit constants.  The output range is found by
    trying every signed 16-bit x; `modular` records whether the result always
    agrees with x (or with -x when flipped) modulo `modulus`.  Results are
    cached per (A,B,C,flipped), because the flipped range is the negation of
    the unflipped one.
    """
    if operation[v][0] != '__sub__': return
    if flipped:
        y,x = parents[v]
    else:
        x,y = parents[v]
    if operation[y][0] != '__mul__': return
    # __mul__ is variadic in the grammar; only the two-operand form matches
    if len(parents[y]) != 2: return
    z,C = parents[y]
    if operation[z][0] != 'mulhrs16':
        C,z = parents[y]
        if operation[z][0] != 'mulhrs16':
            return
    if operation[C][:2] != ['constant',16]: return
    u,B = parents[z]
    if operation[u][0] != 'mulhi16':
        B,u = parents[z]
        if operation[u][0] != 'mulhi16':
            return
    if operation[B][:2] != ['constant',16]: return
    x2,A = parents[u]
    if x2 != x:
        A,x2 = parents[u]
        if x2 != x:
            return
    if operation[A][:2] != ['constant',16]: return
    A = operation[A][2]
    B = operation[B][2]
    C = operation[C][2]
    key = (A,B,C,flipped)
    if key not in reduce3_cache:
        L = None
        H = None
        modular = True
        # XXX: can use low[x],high[x] with more work on caching
        for i in range(-2**15,2**15):
            if flipped:
                j = sub16(mul16(mulhrs16(mulhi16(i,A),B),C),i)
                if (j+i)%modulus: modular = False
            else:
                j = sub16(i,mul16(mulhrs16(mulhi16(i,A),B),C))
                if (j-i)%modulus: modular = False
            if L is None or j < L: L = j
            if H is None or j > H: H = j
        reduce3_cache[key] = L,H,modular
    L,H,modular = reduce3_cache[key]
    narrowrange(v,L,H)
    if exact(v,L,H):
        if modular:
            if x not in formula:
                print('# %s not in formula' % x)
            if x in formula:
                if flipped:
                    print('rem%d = (-1)*rem%d' % (v,x))
                else:
                    print('rem%d = rem%d' % (v,x))
                formula.add(v)
def computerange(v):
    """Initialise v to the full signed range of its width, then narrow it.

    Every rule is tried in a fixed order; each returns immediately when v's
    operation is not the one it handles.
    """
    b = bits[v]
    low[v] = -2**(b-1)
    high[v] = 2**(b-1)-1
    for rule in (
        narrowrange_constant,
        narrowrange_mul,
        narrowrange_mulhi16,
        narrowrange_mulhrs16,
        narrowrange_add_reduce,
        narrowrange_add,
    ):
        rule(v)
    # the subtraction patterns are tried flipped first, then unflipped
    for rule in (
        narrowrange_sub_hilohi,
        narrowrange_reduce2,
        narrowrange_reduce3,
    ):
        rule(v,True)
        rule(v,False)
    for rule in (
        narrowrange_sub,
        narrowrange_lshift,
        narrowrange_rshift,
        narrowrange_and_cond,
        narrowrange_and,
        narrowrange_or,
    ):
        rule(v)
# Main driver: walk the parsed assignments in order.  Each entry p is
# [operation, target, operands...].  For every new value this prints its
# definition, its range as assertsignedminmax(...), and any rem* formula.
for p in program:
    if p[1] in value:
        raise Exception('%s assigned twice' % p[1])
    if p[0] == 'copy':
        # A copy creates no new value.  A source name seen for the first
        # time here is registered as an input.
        if p[2] not in value:
            input(p[2])
            v = nextvalue
            print('v%d = %s' % (v,operation[v][1]))
            print('rem%d = %s' % (v,operation[v][1]))
            formula.add(v)
            print('assertsignedminmax(v%d,%d,%d)' % (v,low[v],high[v]))
        value[p[1]] = value[p[2]]
        if p[1].startswith('out'):
            print('%s = v%d' % (p[1],value[p[1]]))
            if value[p[1]] in formula:
                print('rem%s = rem%d' % (p[1],value[p[1]]))
        continue
    nextvalue += 1
    operation[nextvalue] = [p[0]]
    if p[0] == 'constant':
        # constant(bits,value)
        parents[nextvalue] = []
        operation[nextvalue] += [int(p[2]),int(p[3])]
        bits[nextvalue] = int(p[2])
    elif p[0] in ['__sub__','__rshift__','__lshift__','mulhi16','mulhrs16']:
        # binary size-preserving operation
        assert bits[value[p[2]]] == bits[value[p[3]]]
        parents[nextvalue] = [value[v] for v in p[2:]]
        bits[nextvalue] = bits[value[p[2]]]
    elif p[0] in ['__or__','__and__','__add__','__mul__']:
        # variadic size-preserving operation
        b = bits[value[p[2]]]
        assert all(b == bits[value[v]] for v in p[2:])
        parents[nextvalue] = [value[v] for v in p[2:]]
        bits[nextvalue] = b
    elif p[0] == 'Concat':
        # width is the sum of the operand widths
        parents[nextvalue] = [value[v] for v in p[2:]]
        bits[nextvalue] = sum(bits[v] for v in parents[nextvalue])
    elif p[0] == 'Extract':
        # Extract(x,top,bot) keeps bits bot..top inclusive
        top = int(p[3])
        bot = int(p[4])
        assert top >= bot
        assert bits[value[p[2]]] > top
        assert bot >= 0
        operation[nextvalue] += [top,bot]
        parents[nextvalue] = [value[p[2]]]
        bits[nextvalue] = top + 1 - bot
    elif p[0] in ('SignExt','ZeroExt'):
        # widen by morebits
        morebits = int(p[3])
        operation[nextvalue] += [morebits]
        parents[nextvalue] = [value[p[2]]]
        bits[nextvalue] = bits[value[p[2]]] + morebits
    else:
        raise Exception('unknown internal operation %s' % p[0])
    value[p[1]] = nextvalue
    print('# assigning v%d to source %s' % (nextvalue,p[1]))
    v = nextvalue
    out = ['v%s' % x for x in parents[v]]
    out += ['%s' % x for x in operation[v][1:]]
    print('v%d = %s(%s)' % (v,operation[v][0],','.join(out)))
    computerange(nextvalue)
    print('assertsignedminmax(v%d,%d,%d)' % (v,low[v],high[v]))
    # a value whose range is a single point gets a constant formula
    if low[v] == high[v]:
        if v not in formula:
            print('rem%d = %d' % (v,low[v]))
            formula.add(v)