-rwxr-xr-x 57986 nttcompiler-20220411/scripts/nttcompiler
#!/usr/bin/env python3

# Code-generation options; some of them can be switched on the command line below.
generatecassertions = False
unrollingpragma = False
vectorization = True
countops = False
reduce3 = False

import sys

# Recognized command-line switches; any other argument is ignored.
for arg in sys.argv[1:]:
    if arg == 'generatecassertions':
        generatecassertions = True
    if arg == 'novectorization':
        vectorization = False
    if arg == 'countops':
        countops = True

from collections.abc import Hashable


class memoized(object):
    """Cache a function's results keyed on its positional arguments.

    Calls whose arguments cannot be hashed (for example a list argument)
    bypass the cache and invoke the wrapped function directly.
    """

    def __init__(self, func):
        self.func = func
        self.cache = {}
        self.__name__ = 'memoized:' + func.__name__

    def __call__(self, *args):
        # args is always a tuple, and isinstance(tuple, Hashable) is True even
        # when the tuple contains unhashable items, so a Hashable check can
        # never detect them.  Probe the cache and treat TypeError as
        # "not cacheable" instead.
        try:
            cached = args in self.cache
        except TypeError:
            return self.func(*args)
        if not cached:
            self.cache[args] = self.func(*args)
        return self.cache[args]


def bitreverse(n):
    """Return the binary digits of n, least significant first (0 gives []).

    Comparing these lists lexicographically orders nonnegative integers by
    their bit-reversed value, which is why this serves as a sort key.
    """
    result = []
    while n > 0:
        result += [n & 1]
        n >>= 1
    return result


@memoized
def root(m, q):
    """Return a primitive m-th root of unity modulo q, for m a power of 2.

    root(1) is 1 and root(2) is -1.  For larger m this is the smallest r in
    range(q) with r*r == root(m//2, q) (mod q).

    Raises ValueError if no such r exists.
    """
    assert m >= 1
    assert not m & (m - 1)
    if m == 1:
        return 1
    if m == 2:
        return -1
    rr = root(m // 2, q)
    for r in range(q):
        if (r * r - rr) % q == 0:
            return r
    raise ValueError('primitive %d root mod %d does not exist' % (m, q))


@memoized
def qinv(q):
    """Return an inverse of q modulo 2**16 (not reduced: the value is 2-q).

    Requires q == 1 (mod 2**8).  Then q*(2-q) = 1 - (q-1)**2, and (q-1)**2
    is divisible by 2**16.
    """
    assert q % 2**8 == 1
    result = 2 - q
    assert q * result % 2**16 == 1
    return result


@memoized
def scaleup(r, q):
    """Return r*2**16 mod q, centered into the range -(q//2) .. q//2 for odd q."""
    r *= 2**16
    r %= q
    if r > q // 2:
        r -= q
    return r


@memoized
def scaleqinv(r, q):
    """Return r*qinv(q) mod 2**16 as a signed 16-bit value in [-2**15, 2**15)."""
    r *= qinv(q)
    r %= 2**16
    if r >= 2**15:
        r -= 2**16
    return r


@memoized
def qbits(q):
    """Return the bit length b of q, asserting 2**(b-1) < q < 2**b."""
    result = 0
    while 1 << result < q:
        result += 1
    assert 1 << (result - 1) < q
    assert q < 1 << result
    return result


@memoized
def qshift(q):
    """Return 2**(17 - qbits(q))."""
    return 1 << (17 - qbits(q))


@memoized
def qrecip(q):
    """Return round(2**(14 + qbits(q)) / q)."""
    return int(round((16384 << qbits(q)) / q))


@memoized
def qround32(q):
    """Return round(2**15 / q)."""
    return int(round(32768 / q))


@memoized
def zeta(m, e, q):
    """Return root(m,q)**e mod q, with e reduced mod m.

    e == 0 gives 1 and e == 1 gives root(m, q) unchanged.  All other
    exponents are computed by square-and-multiply and centered so that
    2*r < q.
    """
    e %= m
    if e == 0:
        return 1
    if e == 1:
        return root(m, q)
    if e % 2 == 0:
        r = zeta(m, e // 2, q)
        r *= r
    else:
        r = zeta(m, e - 1, q)
        r *= root(m, q)
    r %= q
    if 2 * r >= q:
        r -= q
    return r


@memoized
def scaledzeta(m, e, q):
    """Return scaleup(zeta(m, e, q), q): root**e times 2**16 mod q, centered."""
    e %= m
    return scaleup(zeta(m, e, q), q)


@memoized
def qinvscaledzeta(m, e, q):
    """Return scaledzeta(m, e, q) times qinv(q), as a signed 16-bit value."""
    e %= m
    return scaleqinv(scaledzeta(m, e, q), q)


@memoized
def scaledzetapow(m, e, powers, q):
    """Return (scaledzeta(m, e*i, q) for i in range(powers)) as a tuple."""
    return tuple(scaledzeta(m, e * i, q) for i in range(powers))
@memoized def qinvscaledzetapow(m,e,powers,q): return tuple(qinvscaledzeta(m,e*i,q) for i in range(powers)) @memoized def mulmod_scaled_bound(x,m,e,q): # mulmod x is between -x and x y = scaledzeta(m,e,q) # mulmod y is exactly this xy = abs(x*y) # mulmod x*y is between -xy and xy lowb = (-xy)//2**16 highb = xy//2**16 # mulmod b is between lowb and highb lowd = -2**15 highd = 2**15-1 # mulmod d is between lowd and highd lowdq = lowd*q highdq = highd*q # mulmod d*q is between lowdq and highdq lowe = lowdq//2**16 highe = highdq//2**16 # mulmod e is between lowe and highe return max(abs(lowb-highe),abs(highb-lowe)) # XXX: test mulmod_scaled_bound def mulhi(x,y): assert x >= -2**15 assert x < 2**15 assert y >= -2**15 assert y < 2**15 z = (x*y)>>16 assert z >= -2**15 assert z < 2**15 return z def mulhrs(x,y): assert x >= -2**15 assert x < 2**15 assert y >= -2**15 assert y < 2**15 z = x*y z >>= 14 z += 1 z >>= 1 z %= 2**16 if z >= 2**15: z -= 2**16 assert z >= -2**15 assert z < 2**15 return z def mullo(x,y): assert x >= -2**15 assert x < 2**15 assert y >= -2**15 assert y < 2**15 z = x*y z %= 2**16 if z >= 2**15: z -= 2**16 assert z >= -2**15 assert z < 2**15 return z @memoized def reduce(x,q): if reduce3: y = mulhi(x,qrecip(q)) y = mulhrs(y,qshift(q)) y = mullo(y,q) else: y = mulhrs(x,qround32(q)) y = mullo(y,q) assert y%q == 0 return x-y @memoized def reduce_bound(q): return max(abs(reduce(x,q)) for x in range(-2**15,2**15)) class modulus: def setbound(self,bound): self.bound = bound pos = self.pos progressions = [] while len(pos) > 0: start = pos[0] if len(pos) == 1: progressions += [(start,start+1,1)] break spacing = pos[1]-pos[0] for j in range(len(pos)+1): if j >= len(pos): break if bound[j] != bound[0]: break if pos[j]-start != j*spacing: break progressions += [(start,start+j*spacing,spacing)] pos = pos[j:] bound = bound[j:] self.boundprogressions = progressions def __init__(self,d,m,e,pos,q,bound): r""" Modulus x^d - root(m)^e, coefficient j being at array position 
pos[j], coefficient j modulo q[k] being between +-bound[j][k]. INPUT: "d" - a positive integer "m" - a power of 2 (1 is allowed) "e" - an integer "pos" - a tuple of d distinct integers "q" - a tuple of allowed moduli "bound" - a tuple of d tuples of nonnegative integers """ d = int(d) assert d >= 1 m = int(m) assert m >= 1 assert not m&(m-1) e = int(e) e %= m pos = tuple(pos) assert len(pos) == d assert len(set(pos)) == d q = tuple(q) bound = tuple(tuple(b) for b in bound) assert len(bound) == d assert all(len(b) == len(q) for b in bound) while m > 1 and e%2 == 0: m //= 2 e //= 2 assert all(all(bp >= 0 and bp <= 32767 for bp in b) for b in bound) self.degree = d self.root = m self.rootpow = e self.pos = pos self.q = q progressions = [] while len(pos) > 0: start = pos[0] if len(pos) == 1: progressions += [(start,start+1,1)] break spacing = pos[1]-pos[0] for j in range(len(pos)+1): if j >= len(pos): break if pos[j]-start != j*spacing: break progressions += [(start,start+j*spacing,spacing)] pos = pos[j:] self.progressions = progressions self.setbound(bound) def assertions(self,N,transformpos): m = self.root e = self.rootpow sign = '-' if m >= 2 and 2*e >= m: assert m%2 == 0 e -= m//2 sign = '+' if e == 0: const = '1' elif e == 1: const = 'zeta%d'%m else: const = 'zeta%d^%d'%(m,e) comment = 'modulus x^%d%s%s' % (self.degree,sign,const) comment += ' pos' for start,stop,spacing in self.progressions: if spacing == 1: comment += ' %s:%s' % (start,stop) else: comment += ' %s:%s:%s' % (start,stop,spacing) comment += ' q %s' % ','.join('%d'%q for q in self.q) comment += ' bound' toprint = tuple(self.bound) numprinted = 0 while len(toprint) > 0: for j in range(len(toprint)+1): if j >= len(toprint): break if toprint[j] != toprint[0]: break numprinted += 1 if numprinted > 8: comment += ' ...' 
break comment += ' %s*(%s)' % (j,','.join('%d'%bp for bp in toprint[0])) toprint = toprint[j:] result = [('blankline',),('comment',comment)] result += [('assertranges',transformpos,list(self.bound),list(self.boundprogressions))] return result def fold(self,other,offset): r""" Return a modulus that merges this modulus with another modulus that is identical except for an offset of positions. INPUT: "other": the other modulus "offset": a nonzero integer, the distance of the other positions from this one """ assert self.degree == other.degree assert self.root == other.root assert self.rootpow == other.rootpow assert self.q == other.q assert other.pos == tuple(j+offset for j in self.pos) newbound = [] for b0,b1 in zip(self.bound,other.bound): b = tuple(max(b0i,b1i) for b0i,b1i in zip(b0,b1)) newbound += [b] return modulus(self.degree,self.root,self.rootpow,self.pos,self.q,newbound) def reduce(self,start,stop,spacing): r""" Return a tuple of moduli produced from this modulus by reducing x[j] for each j in range(start,stop,spacing). INPUT: "start": first j position "stop": after last j position "spacing": spacing between j positions """ J = set(range(start,stop,spacing)) newbound = [] d = self.degree for i in range(d): b = self.bound[i] if self.pos[i] in J: b = tuple(reduce_bound(q) for q in self.q) newbound += [b] result = modulus(d,self.root,self.rootpow,self.pos,self.q,newbound) return result, def twist(self,start,stop,spacing,bm,be): r""" Return a tuple of moduli produced from this modulus by multiplying x[j] for each j in range(start,stop,spacing) by 1,b,b^2,... where b = root(bm)^be. 
INPUT: "start": first j position "stop": after last j position "spacing": spacing between j positions "bm": a power of 2 (1 is allowed) "be": an integer """ assert bm >= 1 assert not bm&(bm-1) be %= bm J = set(range(start,stop,spacing)) S = set(self.pos) if len(J.intersection(S)) == 0: # this operation did not affect us return self, assert tuple(range(start,stop,spacing)) == self.pos d = self.degree m = self.root e = self.rootpow # happy to twist if (root(bm)^be)^d == root(m)^e if (be*d*m-e*bm) % (bm*m) != 0: raise ValueError('this modulus is x^%d-root(%d)^%d, not happy with twist by root(%d)^%d' % (d,m,e,bm,be)) newbound = [] for i in range(d): b = self.bound[i] newb = tuple(mulmod_scaled_bound(bp,bm,be*i,q) for bp,q in zip(b,self.q)) newbound += [newb] result = modulus(d,1,0,self.pos,self.q,newbound) return result, def butterfly(self,start,stop,spacing,offset,bm,be): r""" Return a tuple of moduli produced from this modulus by applying butterflies x[j],x[j+offset] = x[j]+x[j+offset]*b,x[j]-x[j+offset]*b for each j in range(start,stop,spacing) where b = root(bm)^be. Result will have one modulus if the butterflies don't touch this modulus. Result will have two moduli if the butterflies reduce this modulus mod x^(d/2)-b and x^(d/2)+b. INPUT: "start": first j position "stop": after last j position "spacing": spacing between j positions "offset": butterfly distance "bm": a power of 2 (1 is allowed) "be": an integer """ assert bm >= 1 assert not bm&(bm-1) be %= bm J = set(range(start,stop,spacing)) K = set(j+offset for j in range(start,stop,spacing)) assert len(J.intersection(K)) == 0 S = set(self.pos) if len(J.union(K).intersection(S)) == 0: # this operation did not affect us return self, # must have all of our coefficients involved assert J.union(K).intersection(S) == S d = self.degree if d%2: raise ValueError('butterfly requires even number of coefficients; this modulus has %d' % d) m = self.root e = self.rootpow # must have root(m)^e == root(bm)^(2*be) # i.e. 
root(m*bm)^(e*bm) == root(m*bm)^(2*be*bm) if (e*bm-2*be*m) % (m*bm) != 0: raise ValueError('this modulus has root(%d)^%d, cannot use butterfly for root(%d)^%d' % (m,e,bm,be)) for i in range(d//2): assert self.pos[i] in J assert self.pos[i+d//2] == self.pos[i]+offset oldbound0 = self.bound[:d//2] oldbound1 = self.bound[d//2:] newbound = [] for i in range(d//2): b0 = oldbound0[i] b1 = oldbound1[i] if be != 0: b1 = tuple(mulmod_scaled_bound(b1p,bm,be,q) for b1p,q in zip(b1,self.q)) newbound += [[b0p+b1p for b0p,b1p in zip(b0,b1)]] if be == 0: result0 = modulus(d//2,1,0,self.pos[:d//2],self.q,newbound) result1 = modulus(d//2,2,1,self.pos[d//2:],self.q,newbound) else: assert bm%2 == 0 result0 = modulus(d//2,bm,be,self.pos[:d//2],self.q,newbound) result1 = modulus(d//2,bm,be+bm//2,self.pos[d//2:],self.q,newbound) return result0,result1 # ----- STATE OF THE TRANSFORMATION # invariants: # moduli handles "reps" parallel batches; "reps" is run-time variable # each batch has 2^len(transformpos) transforms # each transform has N int16 variables # within each transform, virtual position sum_i 2^i v_i between 0 and N-1 # is stored at physical position sum_i 2^indexing[i] v_i # within each batch, transform sum_j 2^j t_j between 0 and 2^len(transformpos)-1 # is stored at physical position sum_j 2^transformpos[j] t_j class state: def __init__(self,N,qlist): assert N >= 1 assert not N&(N-1) transformpos = [] indexing = [] while 1<<len(indexing) < N: indexing += [len(indexing)] initialbounds = tuple(reduce_bound(q) for q in qlist) # XXX: consider variations allowing larger inputs # e.g. 
current NTT works for initialbounds = tuple(6824 for q in qlist) moduli = modulus(N,1,0,range(N),qlist,N*(initialbounds,)), self.N = N self.transformpos = transformpos self.indexing = indexing self.moduli = moduli self.initialN = N self.qlist = qlist self.qdata = {} self.code = [] self.output = '' def setbound(self,bound): for p in self.moduli: p.setbound(bound) def use(self,*key): if key[0][-1] == '*': howmany = key[-1] key = key[:-1] if key not in self.qdata or howmany > self.qdata[key]: self.qdata[key] = howmany else: self.qdata[key] = True if key == ('function','mulmod_scaled'): self.use('int16','q') self.use('function','mulhi') self.use('function','mullo') self.use('function','sub') if key == ('function','reduce'): if reduce3: self.use('int16','qrecip') self.use('int16','qshift') else: self.use('int16','qround32') self.use('int16','q') self.use('function','mulhi') self.use('function','mulhrs') self.use('function','mullo') self.use('function','sub') if key == ('function','mulmod_scaled_x16'): self.use('int16x16','q_x16') self.use('function','mulhi_x16') self.use('function','mullo_x16') self.use('function','sub_x16') if key == ('function','reduce_x16'): if reduce3: self.use('int16x16','qrecip_x16') self.use('int16x16','qshift_x16') else: self.use('int16x16','qround32_x16') self.use('int16x16','q_x16') self.use('function','mulhi_x16') self.use('function','mulhrs_x16') self.use('function','mullo_x16') self.use('function','sub_x16') def printused(self): qdata = self.qdata used = '' used += '// auto-generated; do not edit\n' used += '\n' if countops: used += '#include "ntt_ops_%d.h"\n' % self.initialN else: used += '#include "ntt_%d.h"\n' % self.initialN used += '\n' if generatecassertions: used += '#include <assert.h>\n' used += '\n' if vectorization: used += '#include <immintrin.h>\n' used += '\n' used += '#define _mm256_permute2x128_si256_lo(f0,f1) _mm256_permute2x128_si256(f0,f1,0x20)\n' used += '#define _mm256_permute2x128_si256_hi(f0,f1) 
_mm256_permute2x128_si256(f0,f1,0x31)\n' used += '#define int16x16 __m256i\n' if unrollingpragma: used += '\n' used += '#define PRAGMA_STRING(s) _Pragma(#s)\n' used += '#if defined __clang__\n' used += '#define UNROLL(n) PRAGMA_STRING(unroll n)\n' used += '#elif defined __GNUC__\n' used += '#define UNROLL(n) PRAGMA_STRING(GCC unroll n)\n' used += '#else\n' used += '#define UNROLL(n)\n' used += '#endif\n' used += '\n' used += 'typedef int16_t int16;\n' used += 'typedef int32_t int32;\n' if countops: used += '\n' used += '#include "ntt_ops.h"\n' used += '#define nummul_count(n) ntt_ops_mul += (n)\n' used += '#define numadd_count(n) ntt_ops_add += (n)\n' used += '#define nummul_x16_count(n) ntt_ops_mul_x16 += (n)\n' used += '#define numadd_x16_count(n) ntt_ops_add_x16 += (n)\n' used += '#define nummulmod_count(n) ntt_ops_mulmod += (n)\n' used += '#define numreduce_count(n) ntt_ops_reduce += (n)\n' firstq = True for q in self.qlist: assert q > 0 assert q%2 assert q < 2**15 used += '\n' if vectorization: used += 'static const int16 __attribute((aligned(32))) qdata_%d[] = {\n' % q else: used += 'static const int16 qdata_%d[] = {\n' % q pos = 0 for resourcetype in 'int16x16','int16','int16*': for key in sorted(qdata): if key[0] != resourcetype: continue resource = key[1] macro = resource for a in key[2:]: macro += '_%s'%a if resource == 'q': result = q elif resource == 'qrecip': result = qrecip(q) elif resource == 'qshift': result = qshift(q) elif resource == 'qround32': result = qround32(q) elif resource == 'scaledzeta': bm,be = key[2:] result = scaledzeta(bm,be,q) elif resource == 'qinvscaledzeta': bm,be = key[2:] result = qinvscaledzeta(bm,be,q) elif resource == 'scaledzeta_pow': bm,be = key[2:] powers = qdata[key] result = scaledzetapow(bm,be,powers,q) elif resource == 'qinvscaledzeta_pow': bm,be = key[2:] powers = qdata[key] result = qinvscaledzetapow(bm,be,powers,q) elif resource == 'q_x16': result = [q]*16 elif resource == 'qrecip_x16': result = [qrecip(q)]*16 elif 
resource == 'qshift_x16': result = [qshift(q)]*16 elif resource == 'qround32_x16': result = [qround32(q)]*16 elif resource == 'scaledzeta_x16': bm,be = key[2:] result = [scaledzeta(bm,be,q)]*16 elif resource == 'qinvscaledzeta_x16': bm,be = key[2:] result = [qinvscaledzeta(bm,be,q)]*16 elif resource.startswith('precomp'): bm,be = key[2:4] jlist = key[4:] result = [scaledzeta(bm,be*j,q) for j in jlist] elif resource.startswith('qinvprecomp'): bm,be = key[2:4] jlist = key[4:] result = [qinvscaledzeta(bm,be*j,q) for j in jlist] else: raise Exception(str((resourcetype,resource))) if not firstq: used += ' // %s\n' % macro else: if resourcetype == 'int16': used += '#define %s qdata[%d]\n' % (macro,pos) elif resourcetype == 'int16*': used += '#define %s (qdata+%d)\n' % (macro,pos) elif resourcetype == 'int16x16': used += '#define %s *(const int16x16 *)(qdata+%d)\n' % (macro,pos) if resourcetype in ('int16x16','int16*'): used += ' %s,\n' % ','.join('%s'%r for r in result) pos += len(result) else: used += ' %s,\n' % result pos += 1 used += '} ;\n' firstq = False if ('function','add') in qdata: used += """ static int16 add(int16 x,int16 y) { """ if countops: used += """\ numadd_count(1); """ used += """\ return x+y; } """ if ('function','sub') in qdata: used += """ static int16 sub(int16 x,int16 y) { """ if countops: used += """\ numadd_count(1); """ used += """\ return x-y; } """ if ('function','mullo') in qdata: used += """ static int16 mullo(int16 x,int16 y) { """ if countops: used += """\ nummul_count(1); """ used += """\ return x*y; } """ if ('function','mulhi') in qdata: used += """ static int16 mulhi(int16 x,int16 y) { """ if countops: used += """\ nummul_count(1); """ used += """\ return (x*(int32)y)>>16; } """ if ('function','mulhrs') in qdata: used += """ static int16 mulhrs(int16 x,int16 y) { """ if countops: used += """\ nummul_count(1); """ used += """\ return (x*(int32)y+16384)>>15; } """ if ('function','mulmod_scaled') in qdata: used += """ static int16 
mulmod_scaled(int16 x,int16 y,int16 qinvy,const int16 *qdata) { """ if countops: used += """\ nummulmod_count(1); """ used += """\ int16 b = mulhi(x,y); int16 d = mullo(x,qinvy); int16 e = mulhi(d,q); return sub(b,e); } """ if ('function','reduce') in qdata: if reduce3: used += """ static int16 reduce(int16 x,const int16 *qdata) { """ if countops: used += """\ numreduce_count(1); """ used += """\ int16 y = mulhi(x,qrecip); y = mulhrs(y,qshift); y = mullo(y,q); return sub(x,y); } """ else: used += """ static int16 reduce(int16 x,const int16 *qdata) { """ if countops: used += """\ numreduce_count(1); """ used += """\ int16 y = mulhrs(x,qround32); y = mullo(y,q); return sub(x,y); } """ if ('function','add_x16') in qdata: used += """ static inline int16x16 add_x16(int16x16 a,int16x16 b) { """ if countops: used += """\ numadd_count(16); numadd_x16_count(1); """ used += """\ return _mm256_add_epi16(a,b); } """ if ('function','sub_x16') in qdata: used += """ static inline int16x16 sub_x16(int16x16 a,int16x16 b) { """ if countops: used += """\ numadd_count(16); numadd_x16_count(1); """ used += """\ return _mm256_sub_epi16(a,b); } """ if ('function','mulmod_scaled_x16') in qdata: used += """ static inline int16x16 mulmod_scaled_x16(int16x16 x,int16x16 y,int16x16 yqinv,const int16 *qdata) { """ if countops: used += """\ nummulmod_count(16); nummul_count(48); nummul_x16_count(3); """ used += """\ int16x16 b = _mm256_mulhi_epi16(x,y); int16x16 d = _mm256_mullo_epi16(x,yqinv); int16x16 e = _mm256_mulhi_epi16(d,q_x16); return sub_x16(b,e); } """ if ('function','reduce_x16') in qdata: if reduce3: used += """ static inline int16x16 reduce_x16(int16x16 x,const int16 *qdata) { """ if countops: used += """\ numreduce_count(16); nummul_count(48); nummul_x16_count(3); """ used += """\ int16x16 y = _mm256_mulhi_epi16(x,qrecip_x16); y = _mm256_mulhrs_epi16(y,qshift_x16); y = _mm256_mullo_epi16(y,q_x16); return sub_x16(x,y); } """ else: used += """ static inline int16x16 
reduce_x16(int16x16 x,const int16 *qdata) { """ if countops: used += """\ numreduce_count(16); nummul_count(32); nummul_x16_count(2); """ used += """\ int16x16 y = _mm256_mulhrs_epi16(x,qround32_x16); y = _mm256_mullo_epi16(y,q_x16); return sub_x16(x,y); } """ self.output = used+self.output def indent(self): self.indentlevel += 1 def unindent(self): assert self.indentlevel > 0 self.indentlevel -= 1 def printline_noindent(self,line): assert '\n' not in line self.output += line+'\n' def printline(self,line): self.output += ' '*self.indentlevel+line+'\n' def codegen_blankline(self): self.output += '\n' def codegen_comment(self,comment): self.printline('// %s' % comment) def codegen_startntt(self,N): assert not self.infunction self.printline('') self.printline('static void ntt%d(int16 *f,long long reps,const int16 *qdata)' % N) self.printline('{') self.indent() self.infunction = True def codegen_stopntt_inv(self,N): assert not self.infunction self.printline('') self.printline('static void invntt%d(int16 *f,long long reps,const int16 *qdata)' % N) self.printline('{') self.indent() self.infunction = True assert self.initialN % (self.N<<len(self.transformpos)) == 0 scale = self.initialN//(self.N<<len(self.transformpos)) self.printline('reps *= %d;' % scale) def codegen_stopntt(self,N): assert self.infunction self.unindent() self.printline('}') for q in self.qlist: self.codegen_blankline() if countops: self.printline('void ntt_ops_%d_%d(int16 *f,long long reps)' % (N,q)) else: self.printline('void ntt_%d_%d(int16 *f,long long reps)' % (N,q)) self.printline('{') self.printline(' ntt%d(f,reps,qdata_%d);' % (N,q)) self.printline('}') self.infunction = False def codegen_startntt_inv(self,N): assert self.infunction self.unindent() self.printline('}') for q in self.qlist: self.codegen_blankline() if countops: self.printline('void ntt_ops_%d_%d_inv(int16 *f,long long reps)' % (N,q)) else: self.printline('void ntt_%d_%d_inv(int16 *f,long long reps)' % (N,q)) self.printline('{') 
self.printline(' invntt%d(f,reps,qdata_%d);' % (N,q)) self.printline('}') self.infunction = False def codegen_startbatch(self,batchsize): assert batchsize == self.batchsize assert self.infunction assert not self.inbatch self.printline('for (long long r = 0;r < reps;++r) {') self.indent() self.inbatch = True self.vectornextvar = [0]*((self.initialN+15)//16) self.vectorvar = [0]*((self.initialN+15)//16) # 0: latest data in f[16*j],f[16*j+1],...,f[16*j+15] # 'a': latest data in int16x16 'a%d'%j # 'b': latest data in int16x16 'b%d'%j # etc. # nextvar is 0 if 'a' is next, 1 if 'b' is next, etc. def codegen_stopbatch(self,batchsize): assert batchsize == self.batchsize assert self.infunction assert self.inbatch self.vector_spillall() assert self.vectorvar == [0]*((self.initialN+15)//16) self.printline('f += %d;' % batchsize) self.unindent() self.printline('}') self.printline('f -= %d*reps;' % batchsize) self.inbatch = False self.vectornextvar = None self.vectorvar = None codegen_startbatch_inv = codegen_stopbatch codegen_stopbatch_inv = codegen_startbatch def codegen_doublereps(self): assert self.infunction assert not self.inbatch assert self.batchsize%2 == 0 self.batchsize //= 2 self.printline('reps *= 2;') def codegen_doublereps_inv(self): assert self.infunction assert not self.inbatch self.batchsize *= 2 self.printline('reps /= 2;') def codegen_physical_unmap(self,indexing,transformpos): self.printline_noindent('#undef F') def codegen_physical_map(self,indexing,transformpos): physical = [] physical += [('t',i,transformpos[i]) for i in range(len(transformpos))] physical += [('v',i,indexing[i]) for i in range(len(indexing))] physical = '+'.join('((((%s)>>%d)&1)<<%d)' % (source,inbit,outbit) for source,inbit,outbit in physical) self.printline_noindent('#define F(t,v) f[%s]'%physical) codegen_physical_unmap_inv = codegen_physical_map codegen_physical_map_inv = codegen_physical_unmap def vector_nextvarname(self,j): v = self.vectornextvar[j] self.vectornextvar[j] = v+1 if v 
< 26: self.vectorvar[j] = 'abcdefghijklmnopqrstuvwxyz'[v] else: self.vectorvar[j] = 'L%d_'%v def vector_ensureloaded(self,j): if self.vectorvar[j] != 0: return self.vector_nextvarname(j) v = self.vectorvar[j] self.printline('int16x16 %s%d = _mm256_loadu_si256((int16x16 *) (f+%d));' % (v,j,j*16)) def vector_spillall(self): N = self.initialN for j in range((N+15)//16): v = self.vectorvar[j] if v != 0: self.printline('_mm256_storeu_si256((int16x16 *) (f+%d),%s%d);' % (j*16,v,j)) self.vectorvar[j] = 0 def selectorder(self,todo): return sorted(todo,key=bitreverse) def codegen_opt_physical_permute(self,perm,oldindexing,oldtransformpos,newindexing,newtransformpos): if not vectorization: return intelperm = False if len(perm) == 2 and perm[0] == 3 and perm[1] >= 4: intelperm = ('_mm256_permute2x128_si256_lo','_mm256_permute2x128_si256_hi',perm[1]) if len(perm) == 2 and perm[0] == 2 and perm[1] >= 4: intelperm = ('_mm256_unpacklo_epi64','_mm256_unpackhi_epi64',perm[1]) if len(perm) == 3 and perm[0] == 1 and perm[1] == 2 and perm[2] >= 4: intelperm = ('_mm256_unpacklo_epi32','_mm256_unpackhi_epi32',perm[2]) if len(perm) == 4 and perm[0] == 0 and perm[1] == 1 and perm[2] == 2 and perm[3] >= 4: intelperm = ('_mm256_unpacklo_epi16','_mm256_unpackhi_epi16',perm[3]) if not intelperm: return N = 1<<len(oldindexing) assert len(oldindexing) == len(newindexing) assert len(oldtransformpos) == len(newtransformpos) g0,g1,offset = intelperm result = [] for j in range(0,N<<len(oldtransformpos),16): if j&(1<<offset): continue result += [('vector_permute',j,j+(1<<offset),g0,g1)] result += [('physical_unmap',oldindexing,oldtransformpos)] result += [('physical_map',newindexing,newtransformpos)] return result def codegen_vector_permute(self,j0,j1,g0,g1): j0 //= 16 j1 //= 16 self.vector_ensureloaded(j0) self.vector_ensureloaded(j1) v0 = self.vectorvar[j0] v1 = self.vectorvar[j1] self.vector_nextvarname(j0) self.vector_nextvarname(j1) w0 = self.vectorvar[j0] w1 = self.vectorvar[j1] 
self.printline('int16x16 %s%d = %s(%s%d,%s%d);' % (w0,j0,g0,v0,j0,v1,j1)) self.printline('int16x16 %s%d = %s(%s%d,%s%d);' % (w1,j1,g1,v0,j0,v1,j1)) def codegen_vector_permute_inv(self,j0,j1,g0,g1): if (g0,g1) == ('_mm256_permute2x128_si256_lo','_mm256_permute2x128_si256_hi'): self.codegen_vector_permute(j0,j1,g0,g1) elif (g0,g1) == ('_mm256_unpacklo_epi64','_mm256_unpackhi_epi64'): self.codegen_vector_permute(j0,j1,g0,g1) elif (g0,g1) == ('_mm256_unpacklo_epi32','_mm256_unpackhi_epi32'): self.codegen_vector_permute(j0,j1,g0,g1) self.codegen_vector_permute(j0,j1,g0,g1) elif (g0,g1) == ('_mm256_unpacklo_epi16','_mm256_unpackhi_epi16'): self.codegen_vector_permute(j0,j1,g0,g1) self.codegen_vector_permute(j0,j1,g0,g1) self.codegen_vector_permute(j0,j1,g0,g1) else: raise ValueError('unknown permutation %s %s' % (g0,g1)) def codegen_physical_permute(self,perm,oldindexing,oldtransformpos,newindexing,newtransformpos): N = 1<<len(oldindexing) assert len(oldindexing) == len(newindexing) assert len(oldtransformpos) == len(newtransformpos) if vectorization: self.codegen_comment('nttcompiler did not vectorize this') self.printline('{') self.printline('int16 rearrange[%d];' % (N<<len(oldtransformpos))) if unrollingpragma: self.printline_noindent('UNROLL(%d)' % min(16,1<<len(oldtransformpos))) self.printline('for (long long t = 0;t < %d;++t)' % (1<<len(oldtransformpos))) if unrollingpragma: self.printline_noindent('UNROLL(%d)' % min(16,N)) self.printline(' for (long long j = 0;j < %d;++j) rearrange[t*%d+j] = F(t,j);' % (N,N)) self.codegen_physical_unmap(oldindexing,oldtransformpos) self.codegen_physical_map(newindexing,newtransformpos) if unrollingpragma: self.printline_noindent('UNROLL(%d)' % min(16,1<<len(newtransformpos))) self.printline('for (long long t = 0;t < %d;++t)' % (1<<len(newtransformpos))) if unrollingpragma: self.printline_noindent('UNROLL(%d)' % min(16,N)) self.printline(' for (long long j = 0;j < %d;++j) F(t,j) = rearrange[t*%d+j];' % (N,N)) self.printline('}') 
def codegen_physical_permute_inv(self,perm,oldindexing,oldtransformpos,newindexing,newtransformpos): self.codegen_physical_permute(reversed(perm),newindexing,newtransformpos,oldindexing,oldtransformpos) def codegen_assertranges(self,transformpos,bound,boundprogressions): if not generatecassertions: return self.vector_spillall() self.printline('for (long long t = 0;t < %d;++t) {' % (1<<len(transformpos))) for k in range(len(self.qlist)): self.printline(' if (q == %d) {' % self.qlist[k]) boundpos = 0 for start,stop,spacing in boundprogressions: boundmove = (stop-start)//spacing assert boundmove > 0 if boundmove == 1: self.printline(' assert(F(t,%d) >= -%d && F(t,%d) <= %d);' % (start,bound[boundpos][k],start,bound[boundpos][k])) else: assert all(bound[boundpos+j][k] == bound[boundpos][k] for j in range(boundmove)) self.printline(' for (long long j = %d;j != %d;j += %d)' % (start,stop,spacing)) self.printline(' assert(F(t,j) >= -%d && F(t,j) <= %d);' % (bound[boundpos][k],bound[boundpos][k])) boundpos += boundmove self.printline(' }') self.printline('}') def codegen_assertranges_inv(self,transformpos,bound,boundprogressions): pass # XXX def codegen_opt_reduce(self,start,stop,spacing,indexing,transformpos): vectorpos = self.vectorpositions(start,stop,spacing,indexing,transformpos) if not vectorpos: return return [('vector_reduce',v) for v in self.selectorder(vectorpos)] def codegen_opt_reduce_ifforward(self,start,stop,spacing,indexing,transformpos): vectorpos = self.vectorpositions(start,stop,spacing,indexing,transformpos) if not vectorpos: return return [('vector_reduce_ifforward',v) for v in self.selectorder(vectorpos)] def codegen_opt_reduce_ifreverse(self,start,stop,spacing,indexing,transformpos): vectorpos = self.vectorpositions(start,stop,spacing,indexing,transformpos) if not vectorpos: return return [('vector_reduce_ifreverse',v) for v in self.selectorder(vectorpos)] def codegen_vector_reduce(self,v): self.use('function','reduce_x16') j = v//16 
self.vector_ensureloaded(j) v = self.vectorvar[j] self.printline('%s%d = reduce_x16(%s%d,qdata);' % (v,j,v,j)) codegen_vector_reduce_inv = codegen_vector_reduce def codegen_reduce(self,start,stop,spacing,indexing,transformpos): self.vector_spillall() if vectorization: self.codegen_comment('nttcompiler did not vectorize this') self.use('function','reduce') if unrollingpragma: self.printline_noindent('UNROLL(%d)' % min(16,1<<len(transformpos))) self.printline('for (long long t = 0;t < %d;++t) {' % (1<<len(transformpos))) if unrollingpragma: self.printline_noindent('UNROLL(%d)' % min(16,(stop-start)//spacing)) self.printline(' for (long long j = %d;j != %d;j += %d) {' % (start,stop,spacing)) self.printline(' int16 f0 = F(t,j);') self.printline(' f0 = reduce(f0,qdata);') self.printline(' F(t,j) = f0;') self.printline(' }') self.printline('}') codegen_reduce_inv = codegen_reduce def codegen_skip(self,*args): pass codegen_reduce_ifforward = codegen_reduce codegen_reduce_ifforward_inv = codegen_skip codegen_vector_reduce_ifforward = codegen_vector_reduce codegen_vector_reduce_ifforward_inv = codegen_skip codegen_reduce_ifreverse = codegen_skip codegen_reduce_ifreverse_inv = codegen_reduce codegen_vector_reduce_ifreverse = codegen_skip codegen_vector_reduce_ifreverse_inv = codegen_vector_reduce def codegen_opt_twist(self,start,stop,spacing,bm,be,indexing,transformpos): vectorpos = self.vectorpositions(start,stop,spacing,indexing,transformpos) if not vectorpos: return result = [] for v in self.selectorder(vectorpos): result += [('vector_twist',v,bm,be,*vectorpos[v])] return result def codegen_vector_twist(self,v,bm,be,*indices): be %= bm self.use('function','mulmod_scaled_x16') self.use('int16x16','precomp',bm,be,*indices) self.use('int16x16','qinvprecomp',bm,be,*indices) precompname = 'precomp_%d_%d'%(bm,be) for e in indices: precompname += '_%s'%e j = v//16 self.vector_ensureloaded(j) v = self.vectorvar[j] self.printline('%s%d = mulmod_scaled_x16(%s%d,%s,qinv%s,qdata);' % 
(v,j,v,j,precompname,precompname)) def codegen_vector_twist_inv(self,v,bm,be,*indices): self.codegen_vector_twist(v,bm,-be,*indices) def codegen_twist(self,start,stop,spacing,bm,be,indexing,transformpos): self.vector_spillall() if vectorization: self.codegen_comment('nttcompiler did not vectorize this') be %= bm self.use('function','mulmod_scaled') self.use('int16*','scaledzeta_pow',bm,be,(stop-start)//spacing) self.use('int16*','qinvscaledzeta_pow',bm,be,(stop-start)//spacing) if unrollingpragma: self.printline_noindent('UNROLL(%d)' % min(16,1<<len(transformpos))) self.printline('for (long long t = 0;t < %d;++t) {' % (1<<len(transformpos))) if unrollingpragma: self.printline_noindent('UNROLL(%d)' % min(16,(stop-start)//spacing)) self.printline(' for (long long j = %d;j != %d;++j) {' % (0,(stop-start)//spacing)) self.printline(' int16 f0 = F(t,%d%+d*j);'% (start,spacing)) self.printline(' f0 = mulmod_scaled(f0,scaledzeta_pow_%d_%d[j],qinvscaledzeta_pow_%d_%d[j],qdata);'% (bm,be,bm,be)) self.printline(' F(t,%d%+d*j) = f0;'% (start,spacing)) self.printline(' }') self.printline('}') def codegen_twist_inv(self,start,stop,spacing,bm,be,indexing,transformpos): self.codegen_twist(start,stop,spacing,bm,-be,indexing,transformpos) def codegen_opt_butterfly(self,start,stop,spacing,offset,bm,be,indexing,transformpos): if not vectorization: return def F(t,j): assert t >= 0 assert t < 1<<len(transformpos) assert j >= 0 assert j < 1<<len(indexing) result = 0 result += sum(((t>>i)&1)<<transformpos[i] for i in range(len(transformpos))) result += sum(((j>>i)&1)<<indexing[i] for i in range(len(indexing))) return result todo = {} todovector = {} for t in range(1<<len(transformpos)): for j in range(start,stop,spacing): physical0 = F(t,j) physical1 = F(t,j+offset) assert physical0 not in todo assert physical1 not in todo todo[physical0] = (0,physical1) todo[physical1] = (1,physical0) physical0vector = physical0&~15 physical1vector = physical1&~15 # self.codegen_comment('physical 
butterfly %d,%d vectors %d,%d' % (physical0,physical1,physical0vector,physical1vector)) if physical0vector not in todovector: todovector[physical0vector] = (0,physical1vector) if physical1vector not in todovector: todovector[physical1vector] = (1,physical0vector) if todovector[physical0vector] != (0,physical1vector): return if todovector[physical1vector] != (1,physical0vector): return for v in todovector: side,vx = todovector[v] for i in range(16): if v+i not in todo or todo[v+i] != (side,vx+i): return result = [] for v0 in self.selectorder(todovector): side,v1 = todovector[v0] if side != 0: continue result += [('vector_butterfly',v0,v1,bm,be)] return result def codegen_vector_butterfly(self,v0,v1,bm,be): self.use('function','add_x16') self.use('function','sub_x16') be %= bm if be != 0: self.use('function','mulmod_scaled_x16') self.use('int16x16','scaledzeta_x16',bm,be) self.use('int16x16','qinvscaledzeta_x16',bm,be) j0 = v0//16 self.vector_ensureloaded(j0) v0 = self.vectorvar[j0] j1 = v1//16 self.vector_ensureloaded(j1) v1 = self.vectorvar[j1] if be != 0: self.printline('%s%d = mulmod_scaled_x16(%s%d,scaledzeta_x16_%d_%d,qinvscaledzeta_x16_%d_%d,qdata);' % (v1,j1,v1,j1,bm,be,bm,be)) self.vector_nextvarname(j0) self.vector_nextvarname(j1) w0 = self.vectorvar[j0] w1 = self.vectorvar[j1] self.printline('int16x16 %s%d = add_x16(%s%d,%s%d);' % (w0,j0,v0,j0,v1,j1)) self.printline('int16x16 %s%d = sub_x16(%s%d,%s%d);' % (w1,j1,v0,j0,v1,j1)) def codegen_vector_butterfly_inv(self,v0,v1,bm,be): self.use('function','add_x16') self.use('function','sub_x16') be = (-be)%bm if be != 0: self.use('function','mulmod_scaled_x16') self.use('int16x16','scaledzeta_x16',bm,be) self.use('int16x16','qinvscaledzeta_x16',bm,be) j0 = v0//16 self.vector_ensureloaded(j0) v0 = self.vectorvar[j0] j1 = v1//16 self.vector_ensureloaded(j1) v1 = self.vectorvar[j1] self.vector_nextvarname(j0) self.vector_nextvarname(j1) w0 = self.vectorvar[j0] w1 = self.vectorvar[j1] self.printline('int16x16 %s%d = 
add_x16(%s%d,%s%d);' % (w0,j0,v0,j0,v1,j1)) self.printline('int16x16 %s%d = sub_x16(%s%d,%s%d);' % (w1,j1,v0,j0,v1,j1)) if be != 0: self.printline('%s%d = mulmod_scaled_x16(%s%d,scaledzeta_x16_%d_%d,qinvscaledzeta_x16_%d_%d,qdata);' % (w1,j1,w1,j1,bm,be,bm,be)) def codegen_butterfly(self,start,stop,spacing,offset,bm,be,indexing,transformpos): self.vector_spillall() if vectorization: self.codegen_comment('nttcompiler did not vectorize this') self.use('function','add') self.use('function','sub') if be != 0: self.use('int16','scaledzeta',bm,be) self.use('int16','qinvscaledzeta',bm,be) if unrollingpragma: self.printline_noindent('UNROLL(%d)' % min(16,1<<len(transformpos))) self.printline('for (long long t = 0;t < %d;++t) {' % (1<<len(transformpos))) if unrollingpragma: self.printline_noindent('UNROLL(%d)' % min(16,(stop-start)//spacing)) self.printline(' for (long long j = %d;j != %d;j += %d) {' % (start,stop,spacing)) self.printline(' int16 f0 = F(t,j);') self.printline(' int16 f1 = F(t,j%+d);' % offset) if be != 0: self.printline(' f1 = mulmod_scaled(f1,scaledzeta_%d_%d,qinvscaledzeta_%d_%d,qdata);' % (bm,be,bm,be)) self.printline(' F(t,j) = add(f0,f1);') self.printline(' F(t,j%+d) = sub(f0,f1);' % offset) self.printline(' }') self.printline('}') def codegen_butterfly_inv(self,start,stop,spacing,offset,bm,be,indexing,transformpos): self.vector_spillall() if vectorization: self.codegen_comment('nttcompiler did not vectorize this') be = -be be %= bm self.use('function','add') self.use('function','sub') if be != 0: self.use('int16','scaledzeta',bm,be) self.use('int16','qinvscaledzeta',bm,be) if unrollingpragma: self.printline_noindent('UNROLL(%d)' % min(16,1<<len(transformpos))) self.printline('for (long long t = 0;t < %d;++t) {' % (1<<len(transformpos))) if unrollingpragma: self.printline_noindent('UNROLL(%d)' % min(16,(stop-start)//spacing)) self.printline(' for (long long j = %d;j != %d;j += %d) {' % (start,stop,spacing)) self.printline(' int16 f0 = F(t,j);') 
self.printline(' int16 f1 = F(t,j%+d);' % offset) self.printline(' F(t,j) = add(f0,f1);') self.printline(' f1 = sub(f0,f1);') if be != 0: self.printline(' f1 = mulmod_scaled(f1,scaledzeta_%d_%d,qinvscaledzeta_%d_%d,qdata);' % (bm,be,bm,be)) self.printline(' F(t,j%+d) = f1;' % offset) self.printline(' }') self.printline('}') def codegen(self): N = self.initialN self.indentlevel = 0 self.batchsize = N self.infunction = False self.inbatch = False self.codegen_blankline() self.codegen_comment('----- codegen pass 1') self.codegen_comment('') for c in self.code: if c[0] == 'comment': self.codegen_comment('// %s' % str(*c[1:])) elif c[0] == 'blankline': self.codegen_comment('') elif c[0] == 'assertranges': self.codegen_comment('assertranges ...') else: self.codegen_comment(' '.join(str(ci) for ci in c)) # ----- initial optimization of each instruction newcode = [] for c in self.code: if c[0] in ('comment','blankline'): continue if c[0] == 'assertranges': if not generatecassertions: continue opt = 'codegen_opt_%s'%c[0] if opt in type(self).__dict__: opt = self.__getattribute__(opt) opt = opt(*c[1:]) if opt: newcode += opt continue newcode += [c] self.code = newcode # ----- eliminate maps if purely vectorized purelyvectorized = True for c in self.code: if c[0] in ('startntt','stopntt','startbatch','stopbatch','doublereps'): continue if c[0] in ('physical_unmap','physical_map'): continue if c[0].startswith('vector_'): continue purelyvectorized = False break if purelyvectorized: newcode = [] for c in self.code: if c[0] in ('physical_unmap','physical_map'): continue newcode += [c] self.code = newcode # ----- reshuffle vectorops = ('vector_butterfly','vector_reduce','vector_reduce_ifforward','vector_reduce_ifreverse','vector_twist','vector_permute') for i in range(1,len(self.code)): j = i while j > 0: wantswap = False if self.code[j-1][0] in ('physical_unmap','physical_map'): if self.code[j][0] in ('stopbatch','startbatch','doublereps'): wantswap = True if 0: # doesn't seem as 
good as the more ad-hoc partitioning below if self.code[j-1][0] in vectorops: if self.code[j][0] in vectorops: inputs0 = self.code[j-1][1:3] if self.code[j-1][0] in ('vector_butterfly','vector_permute') else self.code[j-1][1:2] inputs1 = self.code[j][1:3] if self.code[j][0] in ('vector_butterfly','vector_permute') else self.code[j][1:2] if len(set(inputs0).intersection(set(inputs1))) == 0: if bitreverse(inputs1[-1]) < bitreverse(inputs0[-1]): wantswap = True if not wantswap: break self.code[j-1],self.code[j] = self.code[j],self.code[j-1] j -= 1 # ----- try to partition vector instructions into big lanes code = self.code newcode = [] batchsize = self.initialN while len(code) > 0: for j in range(len(code)): if code[j][0] not in vectorops: break if j == 0: if code[0][0] == 'doublereps': batchsize //= 2 newcode += code[:1] code = code[1:] else: vectorcode,code = code[:j],code[j:] okbiglane = 1 biglane = 1 while True: ok = True for c in vectorcode: inputs = c[1:3] if c[0] in ('vector_butterfly','vector_permute') else c[1:2] if len(set(i%biglane for i in inputs)) > 1: ok = False break if not ok: break okbiglane = biglane biglane *= 2 if biglane > self.initialN: break if okbiglane > 64: okbiglane = 1 for r in range(okbiglane): if r != 0: newcode += [('stopbatch',batchsize),('startbatch',batchsize)] for c in vectorcode: if c[1]%okbiglane == r: newcode += [c] self.code = newcode # ----- peephole progress = True while progress: progress = False for i in range(2,len(self.code)): if self.code[i-2][0] == 'physical_unmap': if self.code[i-1][0] == 'physical_map': if self.code[i][0] == 'physical_unmap': self.code = self.code[:i-1]+self.code[i+1:] progress = True break for i in range(1,len(self.code)): if self.code[i-1][0] == 'startbatch': if self.code[i][0] == 'stopbatch': self.code = self.code[:i-1]+self.code[i+1:] progress = True break # ----- done with pass 2 self.codegen_blankline() self.codegen_comment('----- codegen pass 2') self.codegen_comment('') for c in self.code: if 
c[0] == 'assertranges': self.codegen_comment('assertranges ...') else: self.codegen_comment(' '.join(str(ci) for ci in c)) for c in self.code: if c[0] == 'assertranges': self.codegen_comment('assertranges ...') else: self.codegen_comment(' '.join(str(a) for a in c)) self.__getattribute__('codegen_'+c[0])(*c[1:]) # XXX: track ranges through inverse for c in reversed(self.code): if c[0] == 'assertranges': self.codegen_comment('assertranges ...') else: self.codegen_comment('inv '+' '.join(str(a) for a in c)) self.__getattribute__('codegen_'+c[0]+'_inv')(*c[1:]) assert self.indentlevel == 0 assert not self.inbatch assert not self.infunction def physical_permute(self,*perm): r""" Reshuffle data. Data at physical position sum_i 2^i p_i moves to physical position sum_i 2^NEW(i) p_i where NEW maps perm[0] to perm[1], maps perm[1] to perm[2], etc., maps perm[-1] to perm[0], and fixes any other inputs. """ N = self.N indexing = self.indexing transformpos = self.transformpos NEW = {} for i in range(len(indexing)): NEW[indexing[i]] = indexing[i] for i in range(len(transformpos)): NEW[transformpos[i]] = transformpos[i] for i in range(len(perm)): assert perm[i] in NEW for i in range(len(perm)): NEW[perm[i]] = perm[(i+1)%len(perm)] assert set(indexing+transformpos) == set(NEW[i] for i in indexing).union(set(NEW[i] for i in transformpos)) # old virtual position sum_i 2^i v_i between 0 and N-1 # is stored at physical position sum_i 2^indexing[i] v_i # and in a moment will be stored at # physical position sum_i 2^NEW(indexing[i]) v_i newindexing = [NEW[indexing[i]] for i in range(len(indexing))] newtransformpos = [NEW[transformpos[i]] for i in range(len(transformpos))] self.code += [('physical_permute',perm,tuple(indexing),tuple(transformpos),tuple(newindexing),tuple(newtransformpos))] self.indexing = newindexing self.transformpos = newtransformpos def assertions(self): N = self.N indexing = self.indexing transformpos = self.transformpos moduli = self.moduli assert 1<<len(indexing) 
== N self.code += [('comment','transform size %d' % N)] self.code += [('comment','transform indexing %s' % indexing)] self.code += [('comment','transforms per batch %d' % (1<<len(transformpos)))] self.code += [('comment','batch indexing %s' % transformpos)] self.code += [('comment','total batch size %d' % (N<<len(transformpos)))] if generatecassertions: self.use('int16','q') for p in moduli: self.code += p.assertions(N,transformpos) def startbatch(self): N = self.N transformpos = self.transformpos self.code += [('startbatch',N<<len(transformpos))] def stopbatch(self): N = self.N transformpos = self.transformpos self.code += [('stopbatch',N<<len(transformpos))] def startntt(self): self.code += [('startntt',self.initialN)] self.startbatch() self.code += [('comment','----- PRECONDITIONS')] self.code += [('physical_map',tuple(self.indexing),tuple(self.transformpos))] self.assertions() def startlayer(self,layernum): self.code += [('blankline',)] self.code += [('comment','----- LAYER %d' % layernum)] def stoplayer(self,layernum): self.code += [('blankline',)] self.code += [('comment','----- POSTCONDITIONS AFTER LAYER %d' % layernum)] self.assertions() def announce(self,functionname,args): self.code += [('blankline',)] self.code += [('comment','%s(%s)' % (functionname,','.join(str(a) for a in args)))] def stopntt(self): self.stopbatch() self.code += [('physical_unmap',tuple(self.indexing),tuple(self.transformpos))] self.code += [('stopntt',self.initialN)] def vectorpositions(self,start,stop,spacing,indexing,transformpos): if not vectorization: return False if (stop-start)%spacing: return False def F(t,j): assert t >= 0 assert t < 1<<len(transformpos) assert j >= 0 assert j < 1<<len(indexing) result = 0 result += sum(((t>>i)&1)<<transformpos[i] for i in range(len(transformpos))) result += sum(((j>>i)&1)<<indexing[i] for i in range(len(indexing))) return result emap = {} for t in range(1<<len(transformpos)): for e in range((stop-start)//spacing): j = start+e*spacing assert 
F(t,j) not in emap emap[F(t,j)] = e todo = sorted(emap) result = {} while len(todo) > 0: v = todo[0] if v%16 or len(todo)%16 or todo[:16] != [v+i for i in range(16)]: return False assert v not in result result[v] = [emap[v+i] for i in range(16)] todo = todo[16:] return result def reduce(self,start,stop,spacing): N = self.N indexing = self.indexing transformpos = self.transformpos moduli = self.moduli assert (stop-start)%spacing == 0 assert all(j >= 0 for j in range(start,stop,spacing)) assert all(j < N for j in range(start,stop,spacing)) self.code += [('reduce',start,stop,spacing,tuple(indexing),tuple(transformpos))] newmoduli = () for p in moduli: newmoduli += p.reduce(start,stop,spacing) self.moduli = newmoduli def reduce_ifforward(self,start,stop,spacing): N = self.N indexing = self.indexing transformpos = self.transformpos moduli = self.moduli assert (stop-start)%spacing == 0 assert all(j >= 0 for j in range(start,stop,spacing)) assert all(j < N for j in range(start,stop,spacing)) self.code += [('reduce_ifforward',start,stop,spacing,tuple(indexing),tuple(transformpos))] newmoduli = () for p in moduli: newmoduli += p.reduce(start,stop,spacing) self.moduli = newmoduli def reduce_ifreverse(self,start,stop,spacing): N = self.N indexing = self.indexing transformpos = self.transformpos assert (stop-start)%spacing == 0 assert all(j >= 0 for j in range(start,stop,spacing)) assert all(j < N for j in range(start,stop,spacing)) self.code += [('reduce_ifreverse',start,stop,spacing,tuple(indexing),tuple(transformpos))] def twist(self,start,stop,spacing,bm,be): N = self.N indexing = self.indexing transformpos = self.transformpos moduli = self.moduli assert (stop-start)%spacing == 0 be %= bm assert all(j >= 0 for j in range(start,stop,spacing)) assert all(j < N for j in range(start,stop,spacing)) self.code += [('twist',start,stop,spacing,bm,be,tuple(indexing),tuple(transformpos))] newmoduli = () for p in moduli: newmoduli += p.twist(start,stop,spacing,bm,be) self.moduli = 
newmoduli def twists(self): moduli = self.moduli reducetodo = [] twisttodo = [] for M in moduli: d = M.degree startpos = M.pos[0] assert list(M.pos) == list(range(startpos,startpos+d)) bm = M.root be = M.rootpow if be+be > bm: be -= bm if be == 0: reducetodo += [(startpos,startpos+d,1)] else: twisttodo += [(startpos,startpos+d,1,d*bm,be)] for t in reducetodo: self.reduce(*t) for t in twisttodo: self.twist(*t) def butterfly(self,start,stop,spacing,offset,bm,be): N = self.N indexing = self.indexing transformpos = self.transformpos moduli = self.moduli assert (stop-start)%spacing == 0 be %= bm assert all(j >= 0 for j in range(start,stop,spacing)) assert all(j < N for j in range(start,stop,spacing)) assert all(j+offset >= 0 for j in range(start,stop,spacing)) assert all(j+offset < N for j in range(start,stop,spacing)) self.code += [('butterfly',start,stop,spacing,offset,bm,be,tuple(indexing),tuple(transformpos))] newmoduli = () for p in moduli: newmoduli += p.butterfly(start,stop,spacing,offset,bm,be) self.moduli = newmoduli def butterflies(self): moduli = self.moduli todo = [] for M in moduli: d = M.degree assert d%2 == 0 startpos = M.pos[0] assert list(M.pos) == list(range(startpos,startpos+d)) bm = M.root be = M.rootpow if be+be > bm: be -= bm if be != 0: bm *= 2 todo += [(startpos,startpos+d//2,1,d//2,bm,be)] for t in todo: self.butterfly(*t) def fold(self,offset): N = self.N indexing = self.indexing transformpos = self.transformpos moduli = self.moduli assert len(moduli)%2 == 0 self.code += [('physical_unmap',tuple(indexing),tuple(transformpos))] newmoduli = () for i in range(len(moduli)//2): p0 = moduli[i] p1 = moduli[i+len(moduli)//2] newmoduli += (p0.fold(p1,offset),) moduli = newmoduli transformpos += [indexing[-1]] indexing = indexing[:-1] N //= 2 transformpos.sort() self.N = N self.transformpos = transformpos self.indexing = indexing self.moduli = moduli self.code += [('physical_map',tuple(indexing),tuple(transformpos))] def nextbatch(self): self.stopbatch() 
self.startbatch() def halfbatch(self): # double reps # reducing 2^(len(indexing)+len(transformpos)) transforms/batch # to 2^(len(indexing)+len(transformpos)-1) transforms/batch indexing = self.indexing transformpos = self.transformpos self.code += [('physical_unmap',tuple(indexing),tuple(transformpos))] self.stopbatch() assert len(indexing)+len(transformpos)-1 in transformpos transformpos.remove(len(indexing)+len(transformpos)-1) self.code += [('doublereps',)] self.transformpos = transformpos self.startbatch() self.code += [('physical_map',tuple(indexing),tuple(transformpos))] # ----- SOURCE PROGRAM def doit(): strategy = [[]] qlist = [] N = None for line in sys.stdin: line = line.strip().split() if len(line) == 0: if len(strategy[-1]) > 0: strategy += [[]] continue if line[0] == 'N': N = int(line[1]) continue if line[0] == 'q': qlist += [int(line[1])] continue strategy[-1] += [(line[0],tuple(map(int,line[1:])))] if strategy[-1] == []: strategy = strategy[:-1] S = state(N,qlist) S.startntt() layernum = 1 for layer in strategy: S.startlayer(layernum) for functionname,args in layer: S.announce(functionname,args) S.__getattribute__(functionname)(*args) S.stoplayer(layernum) layernum += 1 S.stopntt() S.codegen() S.printused() sys.stdout.write(S.output) doit()