-rwxr-xr-x 6705 nttcompiler-20220411/scripts/doublecheck
#!/usr/bin/env python3
"""Cross-check a compiled NTT against its straight-line specification.

Usage: doublecheck N ntt [L H] < program

The program on stdin is a sequence of assignments such as
``x = __add__(a, b)`` or ``y = Extract(x, 15, 0)`` plus
``assertsignedminmax(x, lo, hi)`` range checks; lines starting with ``#``
or ``rem`` are ignored.  Every value is a ``(bits, unsigned value)`` pair.

Names starting with ``in`` that are read before being assigned are inputs:
their first integer field (fields are separated by ``_``) is the input
index and their last is the bit width, e.g. ``in_static_f_37_99_16`` is
input 37 with 16 bits.  Each input is drawn at random, from ``[L, H]`` when
both bounds are given and from all ``bits``-bit values otherwise, then
reduced modulo ``2**bits``.  Names starting with ``out`` are outputs,
indexed by their first integer field, e.g. ``out_static_f_37``.

After evaluating the program, the script requires inputs and outputs
0..N-1, all 16 bits wide, feeds the inputs as signed decimal lines to the
executable ``ntt``, and requires its combined stdout/stderr to equal the
computed outputs in the same format.  On success the text read from stdin
is copied unchanged to stdout.
"""

import os
import sys
from functools import reduce
from random import randrange
import subprocess


def group(s):
    """Return a pyparsing parse action that tags a matched statement.

    A match with several tokens becomes the single token ``[s, *tokens]``;
    a one-token match is passed through unchanged.
    """
    def t(x):
        x = list(x)
        if len(x) == 1:
            return x
        return [[s] + x]
    return t


def _build_grammar():
    """Return the pyparsing grammar for a whole program text."""
    # Imported here rather than at module level so the op_* helpers can be
    # used (and tested) without pyparsing installed.
    from pyparsing import StringEnd, Literal, Word, ZeroOrMore, alphas, nums

    lparen = Literal('(').suppress()
    rparen = Literal(')').suppress()
    comma = Literal(',').suppress()
    equal = Literal('=').suppress()
    number = Word(nums)
    pmnumber = Word(nums + "-")
    name = Word(alphas, alphas + nums + "_")

    # Alternatives are tried in order (MatchFirst); the bare copy
    # `x = y` must stay last so it cannot swallow `x = op(...)`.
    assignment = (
        name + equal + Literal('constant').suppress()
        + lparen + number + comma + number + rparen
    ).setParseAction(group('constant'))
    for binary in ['__sub__', '__rshift__', '__lshift__',
                   'LShR', 'mulhi16', 'mulhrs16']:
        assignment |= (
            name + equal + Literal(binary).suppress()
            + lparen + name + comma + name + rparen
        ).setParseAction(group(binary))
    assignment |= (
        name + equal + Literal('Extract').suppress()
        + lparen + name + comma + number + comma + number + rparen
    ).setParseAction(group('Extract'))
    for extension in ['SignExt', 'ZeroExt']:
        assignment |= (
            name + equal + Literal(extension).suppress()
            + lparen + name + comma + number + rparen
        ).setParseAction(group(extension))
    for manyary in ['Concat', '__or__', '__and__', '__add__', '__mul__']:
        assignment |= (
            name + equal + Literal(manyary).suppress()
            + lparen + name + ZeroOrMore(comma + name) + rparen
        ).setParseAction(group(manyary))
    assignment |= (name + equal + name).setParseAction(group('copy'))

    rangecheck = (
        Literal('assertsignedminmax').suppress()
        + lparen + name + comma + pmnumber + comma + pmnumber + rparen
    ).setParseAction(group('assertsignedminmax'))

    return ZeroOrMore(assignment | rangecheck) + StringEnd()


def read_program(stream):
    """Read program text line by line from ``stream``.

    Returns ``(verbatim, program)``: all text read, and the text without
    lines that start with ``#`` or ``rem``.
    """
    verbatim = []
    program = []
    for line in stream:
        verbatim.append(line)
        if line.startswith(('#', 'rem')):
            continue
        # The extra newline keeps statements apart even if a line lacks
        # its terminator; pyparsing skips the surplus whitespace.
        program.append(line + '\n')
    return ''.join(verbatim), ''.join(program)


def _to_signed(v, b):
    """Interpret the unsigned ``b``-bit value ``v`` as two's complement."""
    flip = 2**(b - 1)
    return (v ^ flip) - flip


def _common_width(args):
    """Return the bit width shared by all ``(bits, value)`` operands."""
    assert len(args) > 0
    b = args[0][0]
    assert all(x[0] == b for x in args)
    return b


def op_constant(x, y):
    """Constant ``y`` with bit width ``x``."""
    return x, y


def op___add__(*args):
    """Sum of equal-width operands modulo ``2**bits``."""
    b = _common_width(args)
    return b, reduce((lambda s, t: (s + t) % (2**b)), (x[1] for x in args))


def op___mul__(*args):
    """Product of equal-width operands modulo ``2**bits``."""
    b = _common_width(args)
    return b, reduce((lambda s, t: (s * t) % (2**b)), (x[1] for x in args))


def op___and__(*args):
    """Bitwise AND of equal-width operands."""
    b = _common_width(args)
    return b, reduce((lambda s, t: s & t), (x[1] for x in args))


def op___or__(*args):
    """Bitwise OR of equal-width operands."""
    b = _common_width(args)
    return b, reduce((lambda s, t: s | t), (x[1] for x in args))


def op___sub__(x, y):
    """Difference modulo ``2**bits``."""
    assert x[0] == y[0]
    return x[0], (x[1] - y[1]) % (2**x[0])


def op_assertsignedminmax(x, lo, hi):
    """Assert that ``x``, read as a signed value, lies in ``[lo, hi]``."""
    xsigned = _to_signed(x[1], x[0])
    assert xsigned >= lo
    assert xsigned <= hi


def op_mulhi16(x, y):
    """Signed 16x16-bit product shifted right by 16 (the high half)."""
    b = 16
    assert x[0] == b
    assert y[0] == b
    flip = 2**(b - 1)
    zsigned = (_to_signed(x[1], b) * _to_signed(y[1], b)) >> 16
    # The result must fit in a signed 16-bit word.
    assert (zsigned + flip) ^ flip == zsigned % 2**b
    return b, zsigned % 2**b


def op_mulhrs16(x, y):
    """Signed 16x16-bit product scaled by 2**-15, rounded (ties up)."""
    b = 16
    assert x[0] == b
    assert y[0] == b
    zsigned = (((_to_signed(x[1], b) * _to_signed(y[1], b)) >> 14) + 1) >> 1
    # XXX: a range check like op_mulhi16's would fail if both operands are
    # -32768, so it is not performed; the result simply wraps.
    return b, zsigned % 2**b


def op___lshift__(x, y):
    """Left shift modulo ``2**bits``."""
    assert x[0] == y[0]
    return x[0], (x[1] << y[1]) % (2**x[0])


def op_LShR(x, y):
    """Unsigned (logical) right shift."""
    b = x[0]
    assert b == y[0]
    return b, x[1] >> y[1]


def op___rshift__(x, y):
    """Signed (arithmetic) right shift; the count must lie in ``[0, bits)``."""
    b = x[0]
    assert b == y[0]
    flip = 2**(b - 1)
    xsigned = _to_signed(x[1], b)
    ysigned = _to_signed(y[1], b)
    assert 0 <= ysigned
    assert ysigned < b
    zsigned = xsigned >> ysigned
    # Map the in-range signed result back to its unsigned encoding.
    return b, (zsigned + flip) ^ flip


def op_Concat(*args):
    """Concatenate operands; the first operand is the most significant."""
    pos, value = 0, 0
    for arg in reversed(args):
        pos, value = pos + arg[0], value + (arg[1] << pos)
    return pos, value


def op_Extract(x, top, bot):
    """Bits ``top`` down to ``bot`` (inclusive) of ``x``."""
    assert x[0] > top
    assert top >= bot
    assert bot >= 0
    return top + 1 - bot, ((x[1] & ((2 << top) - 1)) >> bot)


def op_SignExt(x, bits):
    """Widen ``x`` by ``bits`` bits, replicating its sign bit."""
    b, val = x
    if val & (2**(b - 1)):
        return b + bits, val + 2**(b + bits) - 2**b
    return b + bits, val


def op_ZeroExt(x, bits):
    """Widen ``x`` by ``bits`` zero bits."""
    b, val = x
    return b + bits, val


def op_copy(x):
    """Plain copy ``dst = src``."""
    return x


def _name_numbers(name):
    """Return the decimal integer fields of an underscore-separated name."""
    return [int(field) for field in name.split('_') if field.isdecimal()]


def evaluate_program(program, lo=None, hi=None):
    """Evaluate a parsed program on freshly drawn random inputs.

    ``program`` is a sequence of ``[op, dst, operand...]`` statements (or
    ``['assertsignedminmax', name, lo, hi]``) as produced by the grammar.
    Inputs are drawn from ``[lo, hi]`` when ``hi`` is given, otherwise over
    all values of their width.

    Returns ``(inputs, outputs)``: dicts mapping index -> ``(bits, value)``.
    Raises ``Exception`` if a name is assigned twice.
    """
    inputs = {}    # e.g. inputs[37] for in_static_f_37_99_16
    outputs = {}   # e.g. outputs[37] for out_static_f_37
    values = {}    # e.g. values['in_static_f_37_99_16']

    def evaluate(name):
        # Names not yet assigned must be inputs; draw them on first use.
        if name not in values:
            assert name.startswith('in')
            numbers = _name_numbers(name)
            assert len(numbers) >= 2  # first is key, last is numbits
            key = numbers[0]
            assert key not in inputs
            b = numbers[-1]
            if hi is not None:
                v = randrange(lo, hi + 1)
            else:
                v = randrange(2**b)
            v %= 2**b
            inputs[key] = b, v
            values[name] = b, v
        return values[name]

    for p in program:
        opname = p[0]
        if opname == 'assertsignedminmax':
            op_assertsignedminmax(evaluate(p[1]), int(p[2]), int(p[3]))
            continue
        dest = p[1]
        if dest in values:
            raise Exception('%s assigned twice' % dest)
        if opname == 'constant':
            args = [int(pj) for pj in p[2:]]
        elif opname in ('Extract', 'SignExt', 'ZeroExt'):
            # One value operand followed by literal bit positions/counts.
            args = [evaluate(p[2])] + [int(pj) for pj in p[3:]]
        else:
            args = [evaluate(pj) for pj in p[2:]]
        values[dest] = globals()['op_' + opname](*args)
        if dest.startswith('out'):
            numbers = _name_numbers(dest)
            assert len(numbers) >= 1
            outputs[numbers[0]] = values[dest]
    return inputs, outputs


def format_words(words, count):
    """Render ``words[0..count-1]`` (each 16-bit) as signed decimal lines."""
    lines = []
    for j in range(count):
        b, v = words[j]
        assert b == 16
        if v >= 2**15:
            v -= 2**16
        lines.append('%d\n' % v)
    return ''.join(lines)


def main(argv=None):
    """Run the check described in the module docstring.

    ``argv`` defaults to ``sys.argv``.
    """
    if argv is None:
        argv = sys.argv
    count = int(argv[1])
    ntt = argv[2]
    lo, hi = None, None
    # Optional input range; only used when both L and H are supplied.
    if len(argv) > 4:
        lo = int(argv[3])
        hi = int(argv[4])

    assignments = _build_grammar()
    inputcopy, text = read_program(sys.stdin)
    program = list(assignments.parseString(text))

    inputs, outputs = evaluate_program(program, lo, hi)
    assert sorted(inputs) == list(range(count))
    assert sorted(outputs) == list(range(count))

    inputstr = format_words(inputs, count)
    # stderr is merged into stdout, so any diagnostics printed by the
    # program under test also make the comparison fail.
    with subprocess.Popen(ntt, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
                          stderr=subprocess.STDOUT,
                          universal_newlines=True) as proc:
        executestr, _ = proc.communicate(input=inputstr)
    outputstr = format_words(outputs, count)

    # The real test.  Raised explicitly rather than via `assert` so the
    # check is not stripped under `python -O`.
    if executestr != outputstr:
        raise AssertionError(
            '%s output does not match the evaluated program' % ntt)

    # ok, allowed to pass along the input
    sys.stdout.write(inputcopy)


if __name__ == '__main__':
    main()