-rwxr-xr-x 5118 nttcompiler-20220411/scripts/range2linear
#!/usr/bin/env python3
"""Rewrite 'rem...' assignment lines as linear combinations modulo q.

Usage: range2linear q < input

Reads standard input and keeps only the lines that start with 'rem'. Each
kept line is parsed as an assignment ``name = expr``. An expr is built from
names, integer literals, parentheses, the operators + - * /, and ^ (or **)
with a literal integer exponent.

Every right-hand side is evaluated modulo q (sys.argv[1]) to a linear
combination of names that were never assigned; such names act as inputs
with coefficient 1. For every assigned name starting with 'remout', in
assignment order, one line is printed:

    remoutN = c1*in1 + c2*in2 + ...

Terms are sorted by the digits in the input name, zero coefficients are
omitted, and input names are truncated to their first three '_'-separated
components. A result that is a pure constant is printed as that constant.

Division and negative exponents use Fermat inversion, so q is assumed to be
prime.
"""

import sys
import math
from pyparsing import StringEnd,Literal,Word,ZeroOrMore,OneOrMore,Optional,Forward,alphas,nums

q = int(sys.argv[1])


def group(s):
    """Return a parse action that tags a multi-token match as [s, tok, ...].

    A single token is passed through unchanged, so e.g. a term without any
    operator collapses to its only operand.
    """
    def t(x):
        x = list(x)
        if len(x) == 1:
            return x
        return [[s] + x]
    return t


# --- grammar ---------------------------------------------------------------
lparen = Literal('(').suppress()
rparen = Literal(')').suppress()
equal = Literal('=').suppress()
caret = Literal('^').suppress() | Literal('**').suppress()
# Explicit statement terminator inserted after every input line below.
newline = Literal('_NEWLINE').suppress()
timesdiv = Literal('*') | Literal('/')
pm = Literal('+') | Literal('-')
number = (Word(nums) | Literal('-') + Word(nums)).setParseAction(lambda x: [['number'] + list(x)])
name = Word(alphas, alphas + nums + "_").setParseAction(lambda x: [['name'] + list(x)])
expr = Forward()
atom = lparen + expr + rparen | name | number
power = (atom + Optional(caret + number)).setParseAction(group('power'))
term = (power + ZeroOrMore(timesdiv + power)).setParseAction(group('term'))
# Named sumexpr (not 'sum') to avoid shadowing the builtin; the node tag is
# 'expr', which is what the original final setParseAction call installed.
sumexpr = (term + ZeroOrMore(pm + term)).setParseAction(group('expr'))
expr << sumexpr
assignment = (name + equal + expr + newline).setParseAction(group('assignment'))
file_input = ZeroOrMore(newline | assignment) + StringEnd()

# --- input -----------------------------------------------------------------
program = ''.join(line + ' _NEWLINE ' for line in sys.stdin if line.startswith('rem'))
program = list(file_input.parseString(program))

# variables maps each name to its value as a linear combination modulo q:
#   variables['rem10'] = [(3,'in1'),(4,'in2')] means rem10 = 3*in1 + 4*in2
# Each name appears at most once per list ([(3,'in5'),(4,'in5')] is not
# allowed), and zero coefficients are dropped.
# A nonzero constant c is represented as [(c,)]; zero is [].
# XXX: slightly cleaner approach would be [(7,'')]
# XXX: currently no support for mixed forms such as 7+3*in1
variables = {}


def const(value):
    """Return the representation of the integer constant value mod q."""
    value %= q
    if value == 0:
        return []
    return [(value,)]


def isconstant(e):
    """Return True if e is a nonzero constant [(c,)]."""
    return len(e) == 1 and len(e[0]) == 1


def inverse(x):
    """Return x^-1 mod q via Fermat's little theorem (XXX: only for primes)."""
    return pow(x, q - 2, q)


def exprtimes(e, f):
    """Return e*f, where at least one factor must be constant (or zero).

    Raises Exception if both factors are non-constant (nonlinear product).
    """
    if not e or not f:
        # Zero times anything, constant or linear, is zero.
        return []
    if isconstant(e):
        e, f = f, e
    if not isconstant(f):
        raise Exception('nonlinear %s * %s' % (e, f))
    x = f[0][0]
    if isconstant(e):
        return const(x * e[0][0])
    scaled = [((x * c[0]) % q, c[1]) for c in e]
    return [c for c in scaled if c[0] != 0]


def exprpow(e, n):
    """Return e^n for a constant (or zero) base e and a literal integer n.

    n is the exponent exactly as written; it is NOT reduced modulo q.
    Negative exponents use the modular inverse of the base.
    Raises Exception for a non-constant base and ZeroDivisionError for a
    negative power of zero.
    """
    if e and not isconstant(e):
        raise Exception('nonlinear %s ^ %d' % (e, n))
    base = e[0][0] if e else 0
    if n < 0:
        if base == 0:
            raise ZeroDivisionError('negative power of zero modulo %d' % q)
        base, n = inverse(base), -n
    return const(pow(base, n, q))


def exprdiv(e, f):
    """Return e/f, where the divisor f must be a nonzero constant.

    Unlike multiplication the operands are never swapped: 6/2 is 3, not
    2/6, and dividing by a non-constant is rejected as nonlinear.
    Raises ZeroDivisionError if f is zero.
    """
    if not f:
        raise ZeroDivisionError('division by zero modulo %d: %s / %s' % (q, e, f))
    if not isconstant(f):
        raise Exception('nonlinear %s / %s' % (e, f))
    return exprtimes(e, [(inverse(f[0][0]),)])


def combine(e, f, sign):
    """Return e + sign*f for linear combinations e, f (sign is +1 or -1).

    Coefficients of the same name are merged; names whose coefficient
    becomes 0 are dropped. Order: names of e first, then new names of f.
    """
    coeffs = {}
    for c in e:
        coeffs[c[1]] = (coeffs.get(c[1], 0) + c[0]) % q
    for c in f:
        coeffs[c[1]] = (coeffs.get(c[1], 0) + sign * c[0]) % q
    return [(coeffs[v], v) for v in coeffs if coeffs[v] != 0]


def exprplus(e, f):
    """Return e + f; constants cannot be mixed with linear forms here."""
    if isconstant(e) or isconstant(f):
        raise Exception('numerical %s + %s' % (e, f))
    return combine(e, f, 1)


def exprminus(e, f):
    """Return e - f; constants cannot be mixed with linear forms here."""
    if isconstant(e) or isconstant(f):
        raise Exception('numerical exprminus')
    return combine(e, f, -1)


def numbervalue(e):
    """Return the signed integer literal of a 'number' node, unreduced."""
    if e[1] == '-':
        assert len(e) == 3
        return -int(e[2])
    assert len(e) == 2
    return int(e[1])


def evaluate(e):
    """Evaluate a parse-tree node to a linear combination modulo q.

    An unassigned name is registered in `variables` as an input with
    coefficient 1. Raises Exception for unknown nodes or operators.
    """
    kind = e[0]
    if kind == 'number':
        return const(numbervalue(e))
    if kind == 'name':
        assert len(e) == 2
        if e[1] not in variables:
            variables[e[1]] = [(1, e[1])]
        return variables[e[1]]
    if kind == 'power':
        assert len(e) == 3
        # The exponent must stay a plain integer: reducing it mod q (as
        # numbers in expressions are) would compute the wrong power.
        return exprpow(evaluate(e[1]), numbervalue(e[2]))
    if kind == 'term':
        ops = {'*': exprtimes, '/': exprdiv}
    elif kind == 'expr':
        ops = {'+': exprplus, '-': exprminus}
    else:
        raise Exception('unknown expression %s' % e)
    # Nodes are [kind, operand, op, operand, op, operand, ...]; fold
    # left-to-right so chains like a*b*c or a-b+c associate to the left.
    result = evaluate(e[1])
    for j in range(2, len(e), 2):
        if e[j] not in ops:
            raise Exception('unknown expression %s' % e)
        result = ops[e[j]](result, evaluate(e[j + 1]))
    return result


# --- evaluation --------------------------------------------------------------
outputs = []
for p in program:
    if p[0] != 'assignment':
        raise Exception('unknown statement %s' % p)
    assert p[1][0] == 'name'
    target = p[1][1]
    variables[target] = evaluate(p[2])
    if target.startswith('remout'):
        outputs.append(target)


def varnumber(c):
    """Sort key: the decimal digits of the term's name as an int (0 if none).

    Constant terms (c,) have no name and also map to 0.
    """
    if len(c) < 2:
        return 0
    return int(''.join(['0'] + [ch for ch in c[1] if ch in '0123456789']))


def formatterm(c):
    """Render one term as 'coef*name' (name cut to 3 '_' parts) or 'coef'."""
    if len(c) == 1:
        return '%d' % c[0]
    inname = '_'.join(c[1].split('_')[:3])
    return '%d*%s' % (c[0], inname)


# --- output ------------------------------------------------------------------
for o in outputs:
    variables[o].sort(key=varnumber)
    terms = [formatterm(c) for c in variables[o] if c[0] != 0]
    print('%s = %s' % (o, ' + '.join(terms) if terms else '0'))