-rwxr-xr-x 15821 nttcompiler-20220411/scripts/opt2range
#!/usr/bin/env python3

# Range analysis over a straight-line program of fixed-width integer
# operations.
#
# usage: opt2range modulus inputlow inputhigh < program
#
# The program is read from stdin as lines of the form `name = op(args)` or
# `name = othername`; lines starting with '#' are skipped.  Every assigned
# value gets a number N and the script prints, on stdout:
#   vN = ...                      the operation, rewritten over value numbers
#   assertsignedminmax(vN,lo,hi)  the signed range derived for the value
#   remN = ...                    when available, a formula over earlier rem
#                                 values (N is then recorded in `formula`)

import sys
import math
from pyparsing import StringEnd,Literal,Word,ZeroOrMore,OneOrMore,Optional,Forward,alphas,nums

modulus = int(sys.argv[1])
inputlow = int(sys.argv[2])
inputhigh = int(sys.argv[3])


def group(s):
    """Return a parse action that wraps the matched tokens as [s, tok0, tok1, ...].

    A single-token match is passed through unchanged.
    """
    def t(x):
        x = list(x)
        if len(x) == 1:
            return x
        return [[s] + x]
    return t


# ----- grammar

lparen = Literal('(').suppress()
rparen = Literal(')').suppress()
comma = Literal(',').suppress()
equal = Literal('=').suppress()
number = Word(nums)
name = Word(alphas, alphas + nums + "_")

# name = constant(bits,value)
assignment = (
    name + equal + Literal('constant').suppress()
    + lparen + number + comma + number + rparen
).setParseAction(group('constant'))

# name = op(name,name)
for binary in ['__sub__', '__rshift__', '__lshift__', 'mulhi16', 'mulhrs16']:
    assignment |= (
        name + equal + Literal(binary).suppress()
        + lparen + name + comma + name + rparen
    ).setParseAction(group(binary))

# name = Extract(name,top,bot)
assignment |= (
    name + equal + Literal('Extract').suppress()
    + lparen + name + comma + number + comma + number + rparen
).setParseAction(group('Extract'))

# name = SignExt(name,morebits)
assignment |= (
    name + equal + Literal('SignExt').suppress()
    + lparen + name + comma + number + rparen
).setParseAction(group('SignExt'))

# name = ZeroExt(name,morebits)
assignment |= (
    name + equal + Literal('ZeroExt').suppress()
    + lparen + name + comma + number + rparen
).setParseAction(group('ZeroExt'))

# name = op(name,name,...) with one or more operands
for manyary in ['Concat', '__or__', '__and__', '__add__', '__mul__']:
    assignment |= (
        name + equal + Literal(manyary).suppress()
        + lparen + name + ZeroOrMore(comma + name) + rparen
    ).setParseAction(group(manyary))

# name = name; must stay last so that the op(...) forms are tried first
assignment |= (
    name + equal + name
).setParseAction(group('copy'))

assignments = ZeroOrMore(assignment) + StringEnd()

# Same text the original loop accumulated (each kept line followed by an
# extra '\n'), built in one pass instead of by repeated concatenation.
program = ''.join(
    line + '\n' for line in sys.stdin if not line.startswith('#')
)
program = assignments.parseString(program)
program = list(program)


# ----- analysis state

nextvalue = 0

# indexed by variable name:
value = {}

# indexed by value:
operation = {}
parents = {}
bits = {}
low = {}
high = {}

# values for which a rem formula has been printed
formula = set()


# NOTE: this shadows the builtin input(); the name is kept because it is the
# script's existing top-level identifier.
def input(v):
    """Register source variable v as a fresh input value.

    The bit width is taken from the text after the last '_' in the name; the
    range is the [inputlow,inputhigh] given on the command line.
    """
    global nextvalue
    y = v.split('_')
    nextvalue += 1
    value[v] = nextvalue
    operation[nextvalue] = ['input', v]
    parents[nextvalue] = []
    bits[nextvalue] = int(y[-1])
    low[nextvalue] = inputlow
    high[nextvalue] = inputhigh


def exact(v, L, H):
    """Return True if [L,H] fits in the signed range of v's bit width."""
    b = bits[v]
    assert L <= H
    return L >= -2**(b-1) and H < 2**(b-1)


def narrowrange(v, L, H):
    """Intersect v's current range with [L,H], but only if [L,H] is exact for v."""
    if exact(v, L, H):
        low[v] = max(low[v], L)
        high[v] = min(high[v], H)


def _productrange(vs):
    """Return (min,max) of the product of the values vs, by interval arithmetic."""
    L, H = 1, 1
    for x in vs:
        corners = (L*low[x], L*high[x], H*low[x], H*high[x])
        L, H = min(corners), max(corners)
    return L, H


# masked[x,m] = value number of x & m, for m of the form 2^k-1, m != 0
masked = {}


def narrowrange_rshift(v):
    """Range (and possibly rem formula) for x >> d with a fixed distance d."""
    if operation[v][0] != '__rshift__':
        return
    x, d = parents[v]
    dist = low[d]
    if dist != high[d]:
        return
    L = low[x]//2**dist
    H = high[x]//2**dist
    narrowrange(v, L, H)
    if not exact(v, L, H):
        return
    if not all(p in formula for p in parents[v]):
        return
    # a formula is available only if the shifted-out bits x & (2^dist-1)
    # were registered by narrowrange_and
    if (x, 2**dist-1) in masked:
        print('rem%d = (rem%d-rem%d)/2**%d' % (v, x, masked[x, 2**dist-1], dist))
        formula.add(v)


def narrowrange_lshift(v):
    """Range (and possibly rem formula) for x << d with a fixed distance d."""
    if operation[v][0] != '__lshift__':
        return
    x, d = parents[v]
    dist = low[d]
    if dist != high[d]:
        return
    L = low[x]*2**dist
    H = high[x]*2**dist
    narrowrange(v, L, H)
    if not exact(v, L, H):
        return
    if all(p in formula for p in parents[v]):
        print('rem%d = rem%s*2**%d' % (v, parents[v][0], dist))
        formula.add(v)


def narrowrange_and(v):
    """Range for x & y where y (the second operand) is a fixed mask 2^k-1."""
    if operation[v][0] != '__and__':
        return
    # __and__ is parsed as many-ary; only the two-operand form is analyzed
    if len(parents[v]) != 2:
        return
    x, y = parents[v]
    if low[y] != high[y]:
        return
    y = low[y]
    if y < 0:
        return
    if y & (y+1):
        return  # not of the form 2^k-1
    narrowrange(v, 0, y)
    if y != 0:
        masked[x, y] = v
        print('rem%d = quo%d' % (v, v))
        formula.add(v)


def narrowrange_and_cond(v):
    """Range for x & y where one operand has range [-1,0]."""
    if operation[v][0] != '__and__':
        return
    # __and__ is parsed as many-ary; only the two-operand form is analyzed
    if len(parents[v]) != 2:
        return
    x, y = parents[v]
    if (low[y], high[y]) != (-1, 0):
        y, x = parents[v]
        if (low[y], high[y]) != (-1, 0):
            return
    # result is either 0 or x
    narrowrange(v, min(low[x], 0), max(high[x], 0))


def narrowrange_or(v):
    """Range for an OR in which all operands, or all but one, are exactly 0."""
    if operation[v][0] != '__or__':
        return
    numparents = len(parents[v])
    numzero = 0
    for x in parents[v]:
        if low[x] == 0 and high[x] == 0:
            numzero += 1
    if numzero == numparents:
        narrowrange(v, 0, 0)
    if numzero == numparents - 1:
        for x in parents[v]:
            if low[x] != 0 or high[x] != 0:
                narrowrange(v, low[x], high[x])
                if x in formula:
                    print('rem%d = rem%d' % (v, x))
                    formula.add(v)


# range of x+((x>>(b-1))&q) has two cases:
# x >= 0: range of x narrowed to [0,...]
# x < 0: range of x narrowed to [...,-1], plus q
def narrowrange_add_reduce(v):
    """Range (and possibly rem formula) for x + ((x >> (b-1)) & q)."""
    if operation[v][0] != '__add__':
        return
    if len(parents[v]) != 2:
        return
    x, y = parents[v]
    if operation[y][0] != '__and__':
        y, x = parents[v]
        if operation[y][0] != '__and__':
            return
    # __and__ is parsed as many-ary; only the two-operand form matches
    if len(parents[y]) != 2:
        return
    s, q = parents[y]
    if operation[s][0] != '__rshift__':
        q, s = parents[y]
        if operation[s][0] != '__rshift__':
            return
    if parents[s][0] != x:
        return
    d = parents[s][1]
    b = bits[v]
    if low[d] != b-1:
        return
    if high[d] != b-1:
        return

    if low[x] >= 0:
        L, H = low[x], high[x]
    elif high[x] < 0:
        L, H = low[x]+low[q], high[x]+high[q]
    else:
        L1, H1 = 0, high[x]
        L2, H2 = low[x]+low[q], -1+high[q]
        L, H = min(L1, L2), max(H1, H2)
    narrowrange(v, L, H)

    if not exact(v, L, H):
        return
    if x not in formula:
        return
    if low[q] != high[q]:
        return
    if low[q] % modulus == 0:
        print('rem%d = rem%d' % (v, x))
        formula.add(v)


def narrowrange_add(v):
    """Range (and possibly rem formula) for a sum, by interval arithmetic."""
    if operation[v][0] != '__add__':
        return
    L, H = 0, 0
    for x in parents[v]:
        L, H = L+low[x], H+high[x]
    narrowrange(v, L, H)
    if not exact(v, L, H):
        return
    if all(p in formula for p in parents[v]):
        print('rem%d = %s' % (v, '+'.join('rem%d' % p for p in parents[v])))
        formula.add(v)


def narrowrange_constant(v):
    """Set the range of a constant to its value read as a signed bits[v]-bit integer."""
    if operation[v][0] != 'constant':
        return
    b = bits[v]
    c = operation[v][2]
    if c >= 2**(b-1):
        c -= 2**b
    low[v] = c
    high[v] = c


def narrowrange_mul(v):
    """Range (and possibly rem formula) for a product, by interval arithmetic."""
    if operation[v][0] != '__mul__':
        return
    L, H = _productrange(parents[v])
    narrowrange(v, L, H)
    if not exact(v, L, H):
        return
    if all(p in formula for p in parents[v]):
        print('rem%d = %s' % (v, '*'.join('rem%d' % p for p in parents[v])))
        formula.add(v)


def narrowrange_mulhrs16(v):
    """Range for mulhrs16: ((product >> 14) + 1) >> 1 of the product interval."""
    if operation[v][0] != 'mulhrs16':
        return
    L, H = _productrange(parents[v])
    L = ((L//2**14)+1)//2
    H = ((H//2**14)+1)//2
    narrowrange(v, L, H)


def narrowrange_mulhi16(v):
    """Range for mulhi16: product >> 16 of the product interval."""
    if operation[v][0] != 'mulhi16':
        return
    L, H = _productrange(parents[v])
    assert L >= (L//2**16)*2**16
    L = L//2**16
    H = H//2**16
    narrowrange(v, L, H)


def narrowrange_sub(v):
    """Range (and possibly rem formula) for a difference, by interval arithmetic."""
    if operation[v][0] != '__sub__':
        return
    L = low[parents[v][0]]-high[parents[v][1]]
    H = high[parents[v][0]]-low[parents[v][1]]
    narrowrange(v, L, H)
    if not exact(v, L, H):
        return
    if all(p in formula for p in parents[v]):
        print('rem%d = %s' % (v, '-'.join('rem%d' % p for p in parents[v])))
        formula.add(v)


# -----
# assumptions: q odd, 0<q<2^15, r subtraction below is exact
# and qinv is reciprocal of q mod 2^16
# computation: r = mulhi(x,y)-mulhi(mullo(x,y,qinv),q)
# or: r = mulhi(x,y)-mulhi(mullo(x,yqinv),q)
# conclusion: xy = 2^16 r + dq where d is mullo(x,y,qinv)
def narrowrange_sub_hilohi(v, flipped):
    """Range and rem formula for r = mulhi(x,y) - mulhi(d,q), or -r if flipped.

    With flipped, the two operands of the subtraction are taken in the
    opposite order, i.e. v = mulhi(d,q) - mulhi(x,y).
    """
    if operation[v][0] != '__sub__':
        return
    if flipped:
        e, b = parents[v]
    else:
        b, e = parents[v]
    if operation[b][0] != 'mulhi16':
        return
    if operation[e][0] != 'mulhi16':
        return
    x, y = parents[b]
    d, q = parents[e]
    if operation[q][0] != 'constant':
        q, d = parents[e]
        if operation[q][0] != 'constant':
            return
    q = operation[q][2]
    if not (q & 1):
        return
    if q < 0:
        return
    if q > 2**15:
        return

    xylow, xyhigh = _productrange([x, y])
    # x*y is between xylow and xyhigh
    hixylow = xylow//2**16
    hixyhigh = xyhigh//2**16
    # mulhi(x,y) is between hixylow and hixyhigh
    loqlow = -2**15*q
    loqhigh = (2**15-1)*q
    # mullo(...)*q is between loqlow and loqhigh
    hiloqlow = loqlow//2**16
    hiloqhigh = loqhigh//2**16
    # mulhi(mullo(...),q) is between hiloqlow and hiloqhigh
    rlow = hixylow-hiloqhigh
    rhigh = hixyhigh-hiloqlow
    # r is between rlow and rhigh
    if flipped:
        # v is mulhi(d,q)-mulhi(x,y) = -r, so its range is the negated one
        rlow, rhigh = -rhigh, -rlow
    narrowrange(v, rlow, rhigh)

    # the rem formula below is justified only if d really is the low product
    if operation[d][0] != '__mul__':
        return
    if len(parents[d]) not in (2, 3):
        return
    if len(parents[d]) == 2:
        # d = mullo(x,yqinv) with y a constant and yqinv*q = y mod 2^16
        d1, d2 = parents[d]
        if d1 != x:
            d2, d1 = parents[d]
            if d1 != x:
                return
        if operation[d2][:2] != ['constant', 16]:
            return
        if operation[y][0] != 'constant':
            return
        yqinv = operation[d2][2]
        if (yqinv*q-operation[y][2]) % (2**16) != 0:
            return
    if len(parents[d]) == 3:
        # d = mullo(x,y,qinv) in any operand order, with qinv*q = 1 mod 2^16
        d1, d2, d3 = parents[d]
        if d1 != x:
            d1, d2, d3 = d2, d3, d1
        if d1 != x:
            d1, d2, d3 = d2, d3, d1
        if d1 != x:
            return
        if d2 != y:
            d2, d3 = d3, d2
        if d2 != y:
            return
        if operation[d3][:2] != ['constant', 16]:
            return
        qinv = operation[d3][2]
        if (qinv*q) % (2**16) != 1:
            return

    if not exact(v, rlow, rhigh):
        return
    if q % modulus != 0:
        return
    if x not in formula:
        print('# %s not in formula' % x)
    if y not in formula:
        print('# %s not in formula' % y)
    if x in formula and y in formula:
        if flipped:
            print('rem%d = (((-1)*rem%d)*rem%d)/2**16' % (v, x, y))
        else:
            print('rem%d = (rem%d*rem%d)/2**16' % (v, x, y))
        formula.add(v)


# ----- models of the 16-bit operations, on Python integers

def _signed16(c):
    """Reduce c mod 2^16 into the signed range [-2^15,2^15)."""
    c %= 2**16
    if c >= 2**15:
        c -= 2**16
    return c


def sub16(a, b):
    """a-b wrapped to signed 16 bits."""
    return _signed16(a-b)


def mul16(a, b):
    """Low 16 bits of a*b, as a signed value."""
    return _signed16(a*b)


def mulhi16(a, b):
    """(a*b) >> 16; no reduction mod 2^16 is applied before the sign adjustment."""
    c = a*b
    c //= 2**16
    if c >= 2**15:
        c -= 2**16
    return c


def mulhrs16(a, b):
    """(((a*b) >> 14) + 1) >> 1, wrapped to signed 16 bits."""
    c = a*b
    c //= 2**14
    c += 1
    c //= 2
    return _signed16(c)


# Results of the exhaustive scans below.  The scan result depends on flipped,
# so flipped is part of the key.
reduce2_cache = {}
reduce3_cache = {}


def _reducescan(product, flipped):
    """Scan all signed 16-bit i; return (min,max,modular) of i-product(i).

    With flipped, the scanned quantity is product(i)-i instead.  modular is
    True if the result is congruent to i (to -i if flipped) mod modulus for
    every i.
    """
    L = None
    H = None
    modular = True
    # XXX: can use low[x],high[x] with more work on caching
    for i in range(-2**15, 2**15):
        if flipped:
            j = sub16(product(i), i)
            if (j+i) % modulus:
                modular = False
        else:
            j = sub16(i, product(i))
            if (j-i) % modulus:
                modular = False
        if L is None or j < L:
            L = j
        if H is None or j > H:
            H = j
    return L, H, modular


def _reduceemit(v, x, flipped, L, H, modular):
    """Apply a scan result to v and print its rem formula when justified."""
    narrowrange(v, L, H)
    if not exact(v, L, H):
        return
    if not modular:
        return
    if x not in formula:
        print('# %s not in formula' % x)
        return
    if flipped:
        print('rem%d = (-1)*rem%d' % (v, x))
    else:
        print('rem%d = rem%d' % (v, x))
    formula.add(v)


# x-mulhrs16(x,B)*C
# or mulhrs16(x,B)*C-x if flipped
def narrowrange_reduce2(v, flipped):
    """Range and rem formula for x - mulhrs16(x,B)*C with 16-bit constants B, C."""
    if operation[v][0] != '__sub__':
        return
    if flipped:
        y, x = parents[v]
    else:
        x, y = parents[v]
    if operation[y][0] != '__mul__':
        return
    # __mul__ is parsed as many-ary; only the two-operand form matches
    if len(parents[y]) != 2:
        return
    z, C = parents[y]
    if operation[z][0] != 'mulhrs16':
        C, z = parents[y]
        if operation[z][0] != 'mulhrs16':
            return
    if operation[C][:2] != ['constant', 16]:
        return
    x2, B = parents[z]
    if x2 != x:
        B, x2 = parents[z]
        if x2 != x:
            return
    if operation[B][:2] != ['constant', 16]:
        return
    # constants are stored unsigned; the models above take signed operands
    B = _signed16(operation[B][2])
    C = _signed16(operation[C][2])

    key = (B, C, flipped)
    if key not in reduce2_cache:
        reduce2_cache[key] = _reducescan(
            lambda i: mul16(mulhrs16(i, B), C), flipped
        )
    L, H, modular = reduce2_cache[key]
    _reduceemit(v, x, flipped, L, H, modular)


# x-mulhrs16(mulhi16(x,A),B)*C
# or mulhrs16(mulhi16(x,A),B)*C-x if flipped
def narrowrange_reduce3(v, flipped):
    """Range and rem formula for x - mulhrs16(mulhi16(x,A),B)*C with constants A, B, C."""
    if operation[v][0] != '__sub__':
        return
    if flipped:
        y, x = parents[v]
    else:
        x, y = parents[v]
    if operation[y][0] != '__mul__':
        return
    # __mul__ is parsed as many-ary; only the two-operand form matches
    if len(parents[y]) != 2:
        return
    z, C = parents[y]
    if operation[z][0] != 'mulhrs16':
        C, z = parents[y]
        if operation[z][0] != 'mulhrs16':
            return
    if operation[C][:2] != ['constant', 16]:
        return
    u, B = parents[z]
    if operation[u][0] != 'mulhi16':
        B, u = parents[z]
        if operation[u][0] != 'mulhi16':
            return
    if operation[B][:2] != ['constant', 16]:
        return
    x2, A = parents[u]
    if x2 != x:
        A, x2 = parents[u]
        if x2 != x:
            return
    if operation[A][:2] != ['constant', 16]:
        return
    # constants are stored unsigned; the models above take signed operands
    A = _signed16(operation[A][2])
    B = _signed16(operation[B][2])
    C = _signed16(operation[C][2])

    key = (A, B, C, flipped)
    if key not in reduce3_cache:
        reduce3_cache[key] = _reducescan(
            lambda i: mul16(mulhrs16(mulhi16(i, A), B), C), flipped
        )
    L, H, modular = reduce3_cache[key]
    _reduceemit(v, x, flipped, L, H, modular)


def computerange(v):
    """Start v at the full signed range of its width, then apply every rule.

    Each rule returns immediately unless v's operation matches it.  The order
    is kept as is: rules may print rem formulas and add v to `formula`.
    """
    b = bits[v]
    low[v] = -2**(b-1)
    high[v] = 2**(b-1)-1
    narrowrange_constant(v)
    narrowrange_mul(v)
    narrowrange_mulhi16(v)
    narrowrange_mulhrs16(v)
    narrowrange_add_reduce(v)
    narrowrange_add(v)
    narrowrange_sub_hilohi(v, True)
    narrowrange_sub_hilohi(v, False)
    narrowrange_reduce2(v, True)
    narrowrange_reduce2(v, False)
    narrowrange_reduce3(v, True)
    narrowrange_reduce3(v, False)
    narrowrange_sub(v)
    narrowrange_lshift(v)
    narrowrange_rshift(v)
    narrowrange_and_cond(v)
    narrowrange_and(v)
    narrowrange_or(v)


# ----- driver: each p is [op, target, arg, ...]

for p in program:
    if p[1] in value:
        raise Exception('%s assigned twice' % p[1])

    if p[0] == 'copy':
        # a copy from a name not seen before introduces an input
        if p[2] not in value:
            input(p[2])
            v = nextvalue
            print('v%d = %s' % (v, operation[v][1]))
            print('rem%d = %s' % (v, operation[v][1]))
            formula.add(v)
            print('assertsignedminmax(v%d,%d,%d)' % (v, low[v], high[v]))
        # the target becomes another name for the same value number
        value[p[1]] = value[p[2]]
        if p[1].startswith('out'):
            print('%s = v%d' % (p[1], value[p[1]]))
            if value[p[1]] in formula:
                print('rem%s = rem%d' % (p[1], value[p[1]]))
        continue

    nextvalue += 1
    operation[nextvalue] = [p[0]]
    if p[0] == 'constant':
        parents[nextvalue] = []
        operation[nextvalue] += [int(p[2]), int(p[3])]
        bits[nextvalue] = int(p[2])
    elif p[0] in ['__sub__', '__rshift__', '__lshift__', 'mulhi16', 'mulhrs16']:
        # binary size-preserving operation
        assert bits[value[p[2]]] == bits[value[p[3]]]
        parents[nextvalue] = [value[a] for a in p[2:]]
        bits[nextvalue] = bits[value[p[2]]]
    elif p[0] in ['__or__', '__and__', '__add__', '__mul__']:
        b = bits[value[p[2]]]
        assert all(b == bits[value[a]] for a in p[2:])
        parents[nextvalue] = [value[a] for a in p[2:]]
        bits[nextvalue] = b
    elif p[0] == 'Concat':
        parents[nextvalue] = [value[a] for a in p[2:]]
        bits[nextvalue] = sum(bits[a] for a in parents[nextvalue])
    elif p[0] == 'Extract':
        top = int(p[3])
        bot = int(p[4])
        assert top >= bot
        assert bits[value[p[2]]] > top
        assert bot >= 0
        operation[nextvalue] += [top, bot]
        parents[nextvalue] = [value[p[2]]]
        bits[nextvalue] = top + 1 - bot
    elif p[0] in ('SignExt', 'ZeroExt'):
        morebits = int(p[3])
        operation[nextvalue] += [morebits]
        parents[nextvalue] = [value[p[2]]]
        bits[nextvalue] = bits[value[p[2]]] + morebits
    else:
        raise Exception('unknown internal operation %s' % p[0])

    value[p[1]] = nextvalue
    print('# assigning v%d to source %s' % (nextvalue, p[1]))

    v = nextvalue
    out = ['v%s' % x for x in parents[v]]
    out += ['%s' % x for x in operation[v][1:]]
    print('v%d = %s(%s)' % (v, operation[v][0], ','.join(out)))

    computerange(nextvalue)
    print('assertsignedminmax(v%d,%d,%d)' % (v, low[v], high[v]))
    # a value pinned to a single integer gets that integer as its formula
    if low[v] == high[v]:
        if v not in formula:
            print('rem%d = %d' % (v, low[v]))
            formula.add(v)