-rwxr-xr-x 19249 nttcompiler-20220411/scripts/unroll2opt
#!/usr/bin/env python3
"""Optimize a straight-line bit-vector program read from stdin.

Input (stdin) is a sequence of single-assignment statements:
``name = op(arg, ...)``, ``name = constant(bits, value)`` or ``name = other``
(a copy).  The accepted ops are exactly those in the grammar below:
binary ``__sub__ __rshift__ __lshift__ LShR``, ``Extract(x, top, bot)``,
``SignExt(x, n)``, ``ZeroExt(x, n)`` and variadic
``Concat __or__ __and__ __add__ __mul__``.

A name that first appears on the right-hand side of a copy statement becomes
a program input; its bit width is parsed from the text after its last ``_``.
Operands of every other statement must already be defined.  Variables whose
names start with ``out`` are the program outputs.

The statements become a DAG of numbered nodes, which is rewritten by the
passes inside the ``while progress`` loop until an iteration reports no
progress.  Among other rewrites, the passes recognize 16-bit high-half
multiplies (``mulhi16``) and rounding high-half multiplies (``mulhrs16``), and
they regroup sums and differences.

Output (stdout) has one ``# opt loop N`` line per iteration.  After that comes
the optimized program: ``vN = op(vA,...,params)`` (or ``vN = inputname``)
lines in dependency order, followed by ``outname = vN`` for each output.
"""

import sys
from pyparsing import StringEnd,Literal,Word,ZeroOrMore,OneOrMore,Optional,Forward,alphas,nums


def group(s):
    """Return a parse action that tags a matched statement with op name *s*.

    The tokens ``[target, arg, ...]`` become the single token
    ``[s, target, arg, ...]``.  A lone token is passed through unchanged.
    """
    def t(x):
        x = list(x)
        if len(x) == 1:
            return x
        return [[s] + x]
    return t


# ----- grammar: every statement parses to [op, target, args...]

lparen = Literal('(').suppress()
rparen = Literal(')').suppress()
comma = Literal(',').suppress()
equal = Literal('=').suppress()
number = Word(nums)
name = Word(alphas,alphas+nums+"_")

assignment = (
    name + equal + Literal('constant').suppress()
    + lparen + number + comma + number + rparen
).setParseAction(group('constant'))

for binary in ['__sub__','__rshift__','__lshift__','LShR']:
    assignment |= (
        name + equal + Literal(binary).suppress()
        + lparen + name + comma + name + rparen
    ).setParseAction(group(binary))

assignment |= (
    name + equal + Literal('Extract').suppress()
    + lparen + name + comma + number + comma + number + rparen
).setParseAction(group('Extract'))

assignment |= (
    name + equal + Literal('SignExt').suppress()
    + lparen + name + comma + number + rparen
).setParseAction(group('SignExt'))

assignment |= (
    name + equal + Literal('ZeroExt').suppress()
    + lparen + name + comma + number + rparen
).setParseAction(group('ZeroExt'))

for manyary in ['Concat','__or__','__and__','__add__','__mul__']:
    assignment |= (
        name + equal + Literal(manyary).suppress()
        + lparen + name + ZeroOrMore(comma + name) + rparen
    ).setParseAction(group(manyary))

# Plain copy "a = b".  Alternatives joined with | are tried in order, so
# this catch-all form stays last.
assignment |= (name + equal + name).setParseAction(group('copy'))

assignments = ZeroOrMore(assignment) + StringEnd()

program = sys.stdin.read()
program = assignments.parseString(program)
program = list(program)

# ----- build the DAG

nextvalue = 0  # most recently allocated node number

# indexed by variable name: node number currently bound to that name
value = {}

# indexed by node number:
operation = {}  # [op] + op-specific parameters, e.g. ['Extract', top, bot]
parents = {}    # operand node numbers; always lists (merge edits them in place)
bits = {}       # result width in bits


def newinput(v):
    """Create an 'input' node for program variable *v*.

    The width comes from the text after the last '_' in *v*.
    """
    global nextvalue

    y = v.split('_')
    nextvalue += 1
    value[v] = nextvalue
    operation[nextvalue] = ['input',v]
    parents[nextvalue] = []
    bits[nextvalue] = int(y[-1])


for p in program:
    # p is [op, target, args...] as built by group()
    if p[1] in value:
        raise Exception('%s assigned twice' % p[1])

    if p[0] == 'copy':
        # alias the source's node; an unseen source becomes a program input
        if p[2] not in value:
            newinput(p[2])
        value[p[1]] = value[p[2]]
        continue

    nextvalue += 1
    operation[nextvalue] = [p[0]]

    if p[0] == 'constant':
        # constant(bits, value)
        parents[nextvalue] = []
        operation[nextvalue] += [int(p[2]),int(p[3])]
        bits[nextvalue] = int(p[2])
    elif p[0] in ['__sub__','__rshift__','__lshift__','LShR']:
        # binary size-preserving operation
        assert bits[value[p[2]]] == bits[value[p[3]]]
        parents[nextvalue] = [value[v] for v in p[2:]]
        bits[nextvalue] = bits[value[p[2]]]
    elif p[0] in ['__or__','__and__','__add__','__mul__']:
        # variadic size-preserving operation: all operands share one width
        b = bits[value[p[2]]]
        assert all(b == bits[value[v]] for v in p[2:])
        parents[nextvalue] = [value[v] for v in p[2:]]
        bits[nextvalue] = b
    elif p[0] == 'Concat':
        parents[nextvalue] = [value[v] for v in p[2:]]
        bits[nextvalue] = sum(bits[v] for v in parents[nextvalue])
    elif p[0] == 'Extract':
        # Extract(x, top, bot): bits top..bot of x, inclusive
        top = int(p[3])
        bot = int(p[4])
        assert top >= bot
        assert bits[value[p[2]]] > top
        assert bot >= 0
        operation[nextvalue] += [top,bot]
        parents[nextvalue] = [value[p[2]]]
        bits[nextvalue] = top + 1 - bot
    elif p[0] in ('SignExt','ZeroExt'):
        # SignExt/ZeroExt(x, n): widen x by n bits
        morebits = int(p[3])
        operation[nextvalue] += [morebits]
        parents[nextvalue] = [value[p[2]]]
        bits[nextvalue] = bits[value[p[2]]] + morebits
    else:
        raise Exception('unknown internal operation %s' % p[0])

    value[p[1]] = nextvalue

# ----- rewrite the DAG until an iteration makes no progress
#
# Every pass first collects its matches and only then rewrites them.  The
# following do not set progress: the two "x minus that" offer passes, the
# difference flip, the x+((x>>(b-1))&u) extraction, and the merge step.

optloop = 0
progress = True

while progress:
    progress = False
    print('# opt loop %d' % optloop)
    optloop += 1

    # (bits, value) -> node number of an existing constant node
    constants = {}
    for v in operation:
        if operation[v][0] != 'constant': continue
        key = (operation[v][1],operation[v][2])
        constants[key] = v

    # for any __mul__(mulhrs16(mulhi16(x,A),B),C) with constant A, B, C
    # want to have x minus that: add a node x-v if none exists yet.  This is
    # only an offer; if no later pass in this iteration uses the new node,
    # the cleanup below deletes it as unused.

    differences = set()
    for v in operation:
        if operation[v][0] != '__sub__': continue
        if len(parents[v]) != 2: continue
        differences.add((parents[v][0],parents[v][1]))

    hihrslo = []
    for v in operation:
        if operation[v][0] != '__mul__': continue
        # __mul__ is variadic; only the two-operand form can match
        if len(parents[v]) != 2: continue
        c,C = parents[v]
        if operation[C][0] != 'constant':
            C,c = parents[v]
            if operation[C][0] != 'constant': continue
        if operation[c][0] != 'mulhrs16': continue
        b,B = parents[c]
        if operation[B][0] != 'constant':
            B,b = parents[c]
            if operation[B][0] != 'constant': continue
        if operation[b][0] != 'mulhi16': continue
        a,A = parents[b]
        if operation[A][0] != 'constant':
            A,a = parents[b]
            if operation[A][0] != 'constant': continue
        if (a,v) not in differences:
            hihrslo += [(a,v)]

    for a,v in hihrslo:
        nextvalue += 1
        operation[nextvalue] = ['__sub__']
        parents[nextvalue] = [a,v]
        bits[nextvalue] = bits[v]

    # for any __mul__(mulhi16(x,A),B) with constant A and constant B <= 4
    # want to have x minus that (same kind of offer as above)
    # NOTE(review): the original comment here read "__mul__(mulhrs16(x,A),B)
    # ... limit A to constants <=4 to save time", but the code matches
    # mulhi16 and limits B; confirm which was intended.

    differences = set()
    for v in operation:
        if operation[v][0] != '__sub__': continue
        if len(parents[v]) != 2: continue
        differences.add((parents[v][0],parents[v][1]))

    hihrslo = []
    for v in operation:
        if operation[v][0] != '__mul__': continue
        # __mul__ is variadic; only the two-operand form can match
        if len(parents[v]) != 2: continue
        b,B = parents[v]
        if operation[B][0] != 'constant':
            B,b = parents[v]
            if operation[B][0] != 'constant': continue
        if operation[B][2] > 4: continue
        if operation[b][0] != 'mulhi16': continue
        a,A = parents[b]
        if operation[A][0] != 'constant':
            A,a = parents[b]
            if operation[A][0] != 'constant': continue
        if (a,v) not in differences:
            hihrslo += [(a,v)]

    for a,v in hihrslo:
        nextvalue += 1
        operation[nextvalue] = ['__sub__']
        parents[nextvalue] = [a,v]
        bits[nextvalue] = bits[v]

    # given a-b, try ...+a+(c-b) -> ...+c+(a-b)
    # (but do not mark as progress)

    differences = {}
    for v in operation:
        if operation[v][0] != '__sub__': continue
        if len(parents[v]) != 2: continue
        differences[parents[v][0],parents[v][1]] = v

    differenceflips = []
    for v in operation:
        if operation[v][0] != '__add__': continue
        todo = None
        for j in range(len(parents[v])):
            for i in range(j):
                A,D = parents[v][i],parents[v][j]
                for a,d in (A,D),(D,A):
                    if operation[d][0] != '__sub__': continue
                    c,b = parents[d]
                    if (a,b) not in differences: continue
                    todo = (v,i,j,c,differences[a,b])  # last match wins
        if todo is None: continue
        differenceflips += [todo]

    for v,i,j,c,ab in differenceflips:
        # positions i and j held {a, c-b}; the sum is order-independent
        parents[v][i] = c
        parents[v][j] = ab

    # given a-b, do ...+b+(x-a) -> (...+x)-(a-b)

    differences = {}
    for v in operation:
        if operation[v][0] != '__sub__': continue
        if len(parents[v]) != 2: continue
        differences[parents[v][0],parents[v][1]] = v

    differenceflips = []
    for v in operation:
        if operation[v][0] != '__add__': continue
        todo = None
        for j in range(len(parents[v])):
            for i in range(j):
                B,D = parents[v][i],parents[v][j]
                for b,d in (B,D),(D,B):
                    if operation[d][0] != '__sub__': continue
                    x,a = parents[d]
                    if (a,b) not in differences: continue
                    todo = (v,i,j,x,differences[a,b])
        if todo is None: continue
        differenceflips += [todo]

    for v,i,j,x,ab in differenceflips:
        if len(parents[v]) == 2:
            plusx = x
        else:
            # the remaining summands plus x
            nextvalue += 1
            operation[nextvalue] = ['__add__']
            parents[nextvalue] = [parents[v][k] for k in range(len(parents[v])) if k != i and k != j]
            parents[nextvalue] += [x]
            bits[nextvalue] = bits[v]
            plusx = nextvalue
        operation[v] = ['__sub__']
        parents[v] = [plusx,ab]
        progress = True

    # given a-b, do (c+b)-a -> c-(a-b)

    differences = {}
    for v in operation:
        if operation[v][0] != '__sub__': continue
        if len(parents[v]) != 2: continue
        differences[parents[v][0],parents[v][1]] = v

    cablist = []
    for v in operation:
        if operation[v][0] != '__sub__': continue
        s,a = parents[v]
        if operation[s][0] != '__add__': continue
        if len(parents[s]) != 2: continue
        C,B = parents[s]
        todo = None
        for b,c in (B,C),(C,B):
            if (a,b) in differences:
                todo = (v,c,differences[a,b])
        if todo is None: continue
        cablist += [todo]

    for v,c,ab in cablist:
        # Must be a list, not a tuple: the merge step assigns into parents[t]
        # and compares operand lists with ==.
        parents[v] = [c,ab]
        progress = True

    if 0:
        # (disabled) a-b is high priority for simplifying a+b+...
        # NOTE(review): this script was reconstructed from a copy that had
        # lost its indentation.  It is ambiguous whether the trailing
        # "if progress: continue" belonged to this disabled block (as here)
        # or to the loop body.  In the loop body it would let the difference
        # passes above run to a fixpoint before the sum passes below.
        # Confirm.
        differences = set()
        for v in operation:
            if operation[v][0] != '__sub__': continue
            if len(parents[v]) != 2: continue
            differences.add((parents[v][0],parents[v][1]))

        differencesums = {}
        for v in operation:
            if operation[v][0] != '__add__': continue
            if len(parents[v]) != 2: continue
            if (parents[v][0],parents[v][1]) not in differences: continue
            differencesums[parents[v][0],parents[v][1]] = v
            differencesums[parents[v][1],parents[v][0]] = v

        simplifysums = []
        for v in operation:
            if operation[v][0] != '__add__': continue
            if len(parents[v]) <= 2: continue
            todo = None
            for j in range(len(parents[v])):
                for i in range(j):
                    if (parents[v][i],parents[v][j]) in differences:
                        todo = (v,i,j)
            if todo is None: continue
            simplifysums += [todo]

        for v,i,j in simplifysums:
            if (parents[v][i],parents[v][j]) not in differencesums:
                nextvalue += 1
                operation[nextvalue] = ['__add__']
                parents[nextvalue] = [parents[v][i],parents[v][j]]
                bits[nextvalue] = bits[v]
                differencesums[parents[v][i],parents[v][j]] = nextvalue
                differencesums[parents[v][j],parents[v][i]] = nextvalue
            u = differencesums[parents[v][i],parents[v][j]]
            parents[v] = [parents[v][k] for k in range(len(parents[v])) if k != i and k != j]
            parents[v] += [u]
            progress = True

        if progress: continue

    # use a+b to simplify a+b+c+...

    pairsums = {}
    for v in operation:
        if operation[v][0] != '__add__': continue
        if len(parents[v]) != 2: continue
        pairsums[parents[v][0],parents[v][1]] = v
        pairsums[parents[v][1],parents[v][0]] = v

    simplifysums = []
    for v in operation:
        if operation[v][0] != '__add__': continue
        if len(parents[v]) <= 2: continue
        s = None
        for j in range(len(parents[v])):
            for i in range(j):
                if (parents[v][i],parents[v][j]) in pairsums:
                    s = (v,i,j,pairsums[parents[v][i],parents[v][j]])
        if s is None: continue
        simplifysums += [s]

    for v,i,j,u in simplifysums:
        parents[v] = [parents[v][k] for k in range(len(parents[v])) if k != i and k != j]
        parents[v] += [u]
        progress = True

    # (a+b+c)-d, (a+b+d)-c -> (a+b)+(c-d),(a+b)-(c-d)
    # abcdpatterns[a,b] lists (node, c, d) for earlier unmatched (a+b+c)-d

    abcdpatterns = {}
    abcdrewrite = []
    for v in operation:
        if operation[v][0] != '__sub__': continue
        s,d = parents[v]
        if operation[s][0] != '__add__': continue
        if len(parents[s]) != 3: continue
        A,B,C = parents[s]
        todo = None
        for a,b,c in (A,B,C),(A,C,B),(B,A,C),(B,C,A),(C,A,B),(C,B,A):
            if (a,b) not in abcdpatterns: abcdpatterns[a,b] = []
            for u,uc,ud in abcdpatterns[a,b]:
                if (uc,ud) == (d,c):
                    todo = (v,u,a,b,c,d)
        if todo is None:
            for a,b,c in (A,B,C),(A,C,B),(B,A,C),(B,C,A),(C,A,B),(C,B,A):
                abcdpatterns[a,b] += [(v,c,d)]
            continue
        abcdrewrite += [todo]

    abcdrewritten = set()
    for v,u,a,b,c,d in abcdrewrite:
        if v in abcdrewritten: continue
        if u in abcdrewritten: continue
        nextvalue += 1
        operation[nextvalue] = ['__add__']
        parents[nextvalue] = [a,b]
        bits[nextvalue] = bits[v]
        ab = nextvalue
        nextvalue += 1
        operation[nextvalue] = ['__sub__']
        parents[nextvalue] = [c,d]
        bits[nextvalue] = bits[v]
        cd = nextvalue
        operation[v] = ['__add__']
        parents[v] = [ab,cd]
        abcdrewritten.add(v)
        operation[u] = ['__sub__']
        parents[u] = [ab,cd]
        abcdrewritten.add(u)
        progress = True

    # a+b+c, a+b+d -> (a+b)+c, (a+b)+d
    # abcpatterns[a,b] lists (node, c) for earlier unmatched a+b+c

    abcpatterns = {}
    abcrewrite = []
    for v in operation:
        if operation[v][0] != '__add__': continue
        if len(parents[v]) != 3: continue
        A,B,C = parents[v]
        todo = None
        for a,b,c in (A,B,C),(A,C,B),(B,A,C),(B,C,A),(C,A,B),(C,B,A):
            if (a,b) not in abcpatterns: abcpatterns[a,b] = []
            for u,d in abcpatterns[a,b]:
                todo = (v,u,a,b,c,d)
        if todo is None:
            for a,b,c in (A,B,C),(A,C,B),(B,A,C),(B,C,A),(C,A,B),(C,B,A):
                abcpatterns[a,b] += [(v,c)]
            continue
        abcrewrite += [todo]

    abcrewritten = set()
    for v,u,a,b,c,d in abcrewrite:
        if v in abcrewritten: continue
        if u in abcrewritten: continue
        nextvalue += 1
        operation[nextvalue] = ['__add__']
        parents[nextvalue] = [a,b]
        bits[nextvalue] = bits[v]
        ab = nextvalue
        operation[v] = ['__add__']
        parents[v] = [ab,c]
        abcrewritten.add(v)
        operation[u] = ['__add__']
        parents[u] = [ab,d]
        abcrewritten.add(u)
        progress = True

    if 1:
        # (a-b)+mulhi16(...) -> a+(mulhi16(...)-b)
        subaddmul = []
        for v in operation:
            if operation[v][0] != '__add__': continue
            if len(parents[v]) != 2: continue
            c,m = parents[v]
            if operation[m][0] != 'mulhi16':
                m,c = parents[v]
                if operation[m][0] != 'mulhi16': continue
            if operation[c][0] != '__sub__': continue
            a,b = parents[c]
            # If a is itself a mulhi16, the result a+(m-b) matches this
            # pattern again with a and m swapped.  Because the rewrite sets
            # progress, the loop would then never terminate.
            if operation[a][0] == 'mulhi16': continue
            subaddmul += [(v,a,b,m)]

        for v,a,b,m in subaddmul:
            nextvalue += 1
            operation[nextvalue] = ['__sub__']
            parents[nextvalue] = [m,b]
            bits[nextvalue] = bits[v]
            operation[v] = ['__add__']
            parents[v] = [a,nextvalue]
            progress = True

    # extract x+((x>>(b-1))&u) from sums of 3+ terms (b = width of the sum)
    reduce = []
    for v in operation:
        if operation[v][0] != '__add__': continue
        if len(parents[v]) < 3: continue
        b = bits[v]
        vreduce = False  # at most one extraction per sum
        for i in range(len(parents[v])):
            if vreduce: continue
            y = parents[v][i]
            if operation[y][0] != '__and__': continue
            # __and__ is variadic; only the two-operand form can match
            if len(parents[y]) != 2: continue
            t,u = parents[y]
            if operation[t][0] != '__rshift__':
                u,t = parents[y]
                if operation[t][0] != '__rshift__': continue
            x,d = parents[t]
            if operation[d] != ['constant',b,b-1]: continue
            for j in range(len(parents[v])):
                if vreduce: continue
                if parents[v][j] != x: continue
                reduce += [(v,x,y)]
                vreduce = True

    for v,x,y in reduce:
        i = parents[v].index(y)
        j = parents[v].index(x)
        assert i != j
        nextvalue += 1
        operation[nextvalue] = ['__add__']
        parents[nextvalue] = [x,y]
        bits[nextvalue] = bits[v]
        parents[v] = [parents[v][k] for k in range(len(parents[v])) if k != i and k != j]
        parents[v] += [nextvalue]

    # (SignExt(x,16)*constant(32,y))[31:16] -> mulhi16(x,constant(16,y))
    # if x has 16 bits and -2^15 <= y < 2^15

    mulhi16 = []
    for v in operation:
        if operation[v] != ['Extract',31,16]: continue
        hi = parents[v][0]
        if operation[hi][0] != '__mul__': continue
        # __mul__ is variadic; only the two-operand form can match
        if len(parents[hi]) != 2: continue
        s,c = parents[hi]
        if operation[s] != ['SignExt',16]:
            c,s = parents[hi]
            if operation[s] != ['SignExt',16]: continue
        if operation[c][0] != 'constant': continue
        if operation[c][1] != 32: continue
        y = operation[c][2]
        if y >= 2**31: y -= 2**32  # read the 32-bit constant as signed
        if y < -2**15: continue
        if y >= 2**15: continue
        y %= 2**16
        x = parents[s][0]
        if bits[x] != 16: continue
        mulhi16 += [(v,x,y)]

    for v,x,y in mulhi16:
        if (16,y) in constants:
            c = constants[16,y]
        else:
            nextvalue += 1
            operation[nextvalue] = ['constant',16,y]
            parents[nextvalue] = []
            bits[nextvalue] = 16
            c = nextvalue
            # Record the new constant for reuse.  Constants have no operands,
            # so the merge step below never merges duplicate constants.
            constants[16,y] = c
        operation[v] = ['mulhi16']
        parents[v] = [x,c]
        progress = True

    # Extract(LShR(LShR(__rshift__(Concat(x,x),16)*y,14)+1,1),15,0)
    # -> mulhrs16(x,y)
    # if x has 16 bits and -2^15 <= y < 2^15

    mulhrs16 = []
    for v in operation:
        if operation[v] != ['Extract',15,0]: continue
        a = parents[v][0]
        if operation[a][0] != 'LShR': continue
        b,c = parents[a]
        if operation[c] != ['constant',32,1]: continue
        if operation[b][0] != '__add__': continue
        # __add__ is variadic; only the two-operand form can match
        if len(parents[b]) != 2: continue
        d,e = parents[b]
        if operation[d][0] != 'LShR':
            e,d = parents[b]
            if operation[d][0] != 'LShR': continue
        if operation[e] != ['constant',32,1]: continue
        f,g = parents[d]
        if operation[g] != ['constant',32,14]: continue
        if operation[f][0] != '__mul__': continue
        # __mul__ is variadic; only the two-operand form can match
        if len(parents[f]) != 2: continue
        h,i = parents[f]
        if operation[h][0] != '__rshift__':
            i,h = parents[f]
            if operation[h][0] != '__rshift__': continue
        if operation[i][:2] != ['constant',32]: continue
        y = operation[i][2]
        if y >= 2**31: y -= 2**32  # read the 32-bit constant as signed
        if y < -2**15: continue
        if y >= 2**15: continue
        y %= 2**16
        j,k = parents[h]
        if operation[k] != ['constant',32,16]: continue
        if operation[j][0] != 'Concat': continue
        if len(parents[j]) != 2: continue
        x = parents[j][0]
        if x != parents[j][1]: continue
        if bits[x] != 16: continue
        mulhrs16 += [(v,x,y)]

    for v,x,y in mulhrs16:
        if (16,y) in constants:
            c = constants[16,y]
        else:
            nextvalue += 1
            operation[nextvalue] = ['constant',16,y]
            parents[nextvalue] = []
            bits[nextvalue] = 16
            c = nextvalue
            constants[16,y] = c  # reuse within this pass (see mulhi16 above)
        operation[v] = ['mulhrs16']
        parents[v] = [x,c]
        progress = True

    # ----- clean up unused nodes, merge
    #
    # children[z] holds the nodes that use z as an operand.  -1 marks that z
    # is bound to an output variable (name starting with 'out').

    children = {z: set() for z in operation}
    for v in value:
        if v.startswith('out'):
            children[value[v]].add(-1)
    for z in operation:
        for x in parents[z]:
            children[x].add(z)

    deleting = set(v for v in operation if len(children[v]) == 0)

    # Two users y, z of a common operand are merged when they have equal
    # operations and equal operand lists (operand sets for signedmin /
    # signedmax).  Each node is in at most one merge per iteration.
    # NOTE(review): merging does not set progress, so merges exposed only by
    # this iteration's merges are found only if some pass makes progress.
    merging = deleting.copy()
    merge = []
    for x in operation:
        c = list(children[x])
        for y,z in [(c[i],c[j]) for j in range(len(c)) for i in range(j)]:
            if y == -1: continue
            if z == -1: continue
            if operation[y] != operation[z]: continue
            parentsmatch = False
            if parents[y] == parents[z]: parentsmatch = True
            if operation[y][0] in ['signedmin','signedmax']:
                if set(parents[y]) == set(parents[z]):
                    parentsmatch = True
            if not parentsmatch: continue
            assert bits[y] == bits[z]
            if y in merging: continue
            if z in merging: continue
            merge += [(y,z)]
            merging.add(y)
            merging.add(z)

    for y,z in merge:
        # eliminate z in favor of y
        for t in children[z]:
            if t == -1:
                for v in value:
                    if v.startswith('out'):
                        if value[v] == z:
                            value[v] = y
            else:
                for j in range(len(parents[t])):
                    if parents[t][j] == z:
                        parents[t][j] = y
        deleting.add(z)

    for v in deleting:
        del operation[v]
        del parents[v]
        del bits[v]

# ----- print the optimized program

done = set()


def do(v):
    """Print node *v*, after first printing (once) every node it depends on.

    This recurses once per level of DAG depth, so a very deep program could
    hit Python's recursion limit.
    """
    if v in done: return
    done.add(v)
    for x in parents[v]:
        do(x)
    if operation[v][0] == 'input':
        print('v%d = %s' % (v,operation[v][1]))
    else:
        p = ['v%s' % x for x in parents[v]]
        p += ['%s' % x for x in operation[v][1:]]
        print('v%d = %s(%s)' % (v,operation[v][0],','.join(p)))


for v in value:
    if not v.startswith('out'): continue
    do(value[v])
    print('%s = v%d' % (v,value[v]))