-rwxr-xr-x 57986 nttcompiler-20220411/scripts/nttcompiler
#!/usr/bin/env python3
# Code-generation switches. Command-line words can change three of them.
generatecassertions = False # 'generatecassertions': emit C assert() range checks
unrollingpragma = False # emit UNROLL(n) pragmas before generated loops (no command-line word)
vectorization = True # 'novectorization' turns off the AVX2 int16x16 code paths
countops = False # 'countops': emit operation-counting instrumentation
reduce3 = False # reduce with mulhi+mulhrs+mullo instead of mulhrs+mullo (no command-line word)
import sys
for arg in sys.argv[1:]:
  if arg == 'countops':
    countops = True
  elif arg == 'novectorization':
    vectorization = False
  elif arg == 'generatecassertions':
    generatecassertions = True
from collections.abc import Hashable
class memoized(object):
  """Decorator caching func's results, keyed on its positional arguments.

  Calls whose arguments cannot be hashed bypass the cache and call func directly.
  """
  def __init__(self,func):
    self.func = func
    self.cache = {}
    self.__name__ = 'memoized:' + func.__name__
  def __call__(self,*args):
    # isinstance(args,Hashable) is always true for a tuple, even one holding a list,
    # so actually hash the arguments to decide whether they can be a cache key.
    try:
      hash(args)
    except TypeError:
      return self.func(*args)
    if args not in self.cache:
      self.cache[args] = self.func(*args)
    return self.cache[args]
def bitreverse(n):
  """Return the binary digits of n as a list, least significant bit first.

  Returns [] for n <= 0. Sorting integers by this list (as selectorder does)
  visits them in bit-reversed order.
  """
  if n <= 0:
    return []
  return [(n >> i) & 1 for i in range(n.bit_length())]
@memoized
def root(m,q):
  """Return an m-th root of unity modulo q, for m a power of 2.

  root(1,q) = 1 and root(2,q) = -1. For larger m the result is the smallest r
  in range(q) with r^2 = root(m//2,q) (mod q).

  Raises ValueError if no such r exists.
  """
  assert m >= 1
  assert not m&(m-1)
  if m <= 2:
    return 1 if m == 1 else -1
  half = root(m//2,q)
  found = next((r for r in range(q) if (r*r-half) % q == 0),None)
  if found is None:
    raise ValueError('primitive %d root mod %d does not exist' % (m,q))
  return found
@memoized
def qinv(q):
  """Return an inverse of q modulo 2^16 (not reduced). Requires q = 1 (mod 2^8).

  This is one Newton step from 1: q*(2-q) = 1-(q-1)^2, and (q-1)^2 is a
  multiple of 2^16 when 2^8 divides q-1.
  """
  assert q % 256 == 1
  inverse = 2 - q
  assert (q*inverse) % 65536 == 1
  return inverse
@memoized
def scaleup(r,q):
  """Return r*2^16 mod q, shifted down by q when it exceeds q//2."""
  scaled = r * 65536 % q
  return scaled - q if scaled > q // 2 else scaled
@memoized
def scaleqinv(r,q):
  """Return r*qinv(q) modulo 2^16 as a signed int16 value in [-2^15,2^15)."""
  product = (r * qinv(q)) % 65536
  return product - 65536 if product >= 32768 else product
@memoized
def qbits(q):
  """Return b with 2^(b-1) < q < 2^b, i.e. the bit length of q (q must not be a power of 2)."""
  # smallest b with 2^b >= q; q <= 0 gives 0 and the first assert then fails like the loop version
  bits = (q-1).bit_length() if q > 0 else 0
  assert 1<<(bits-1) < q
  assert q < 1<<bits
  return bits
@memoized
def qshift(q):
  """Return 2^(17-qbits(q)), the mulhrs multiplier used by the reduce3 variant of reduce."""
  exponent = 17 - qbits(q)
  return 1 << exponent
@memoized
def qrecip(q):
  """Return 2^(14+qbits(q))/q rounded to an integer (mulhi multiplier for the reduce3 variant)."""
  numerator = 16384 << qbits(q)
  return int(round(numerator / q))
@memoized
def qround32(q):
  """Return 2^15/q rounded to an integer (mulhrs multiplier used by the default reduce)."""
  quotient = 32768 / q
  return int(round(quotient))
@memoized
def zeta(m,e,q): # root(m,q)^e mod q
  """Return root(m,q)^e modulo q.

  e = 0 gives 1 and e = 1 (mod m) gives root(m,q) unchanged. Other exponents
  give the representative r with 2*r < q, i.e. values from (q+1)//2 upward are shifted down by q.
  """
  e %= m
  if e == 0:
    return 1
  base = root(m,q)
  if e == 1:
    return base
  r = pow(base,e,q)
  if 2*r >= q: r -= q
  return r
@memoized
def scaledzeta(m,e,q):
  """Return zeta(m,e,q) scaled by 2^16 modulo q (see scaleup), the y operand of mulmod_scaled."""
  return scaleup(zeta(m,e%m,q),q)
@memoized
def qinvscaledzeta(m,e,q):
  """Return scaledzeta(m,e,q)*qinv(q) as int16, the qinvy operand of mulmod_scaled."""
  return scaleqinv(scaledzeta(m,e%m,q),q)
@memoized
def scaledzetapow(m,e,powers,q):
  """Return the tuple scaledzeta(m,e*i,q) for i = 0,...,powers-1."""
  values = []
  for i in range(powers):
    values.append(scaledzeta(m,e*i,q))
  return tuple(values)
@memoized
def qinvscaledzetapow(m,e,powers,q):
  """Return the tuple qinvscaledzeta(m,e*i,q) for i = 0,...,powers-1."""
  exponents = [e*i for i in range(powers)]
  return tuple(qinvscaledzeta(m,x,q) for x in exponents)
@memoized
def mulmod_scaled_bound(x,m,e,q):
  """Return a bound on |mulmod_scaled(a,y,qinvy)| over all |a| <= x, with y = scaledzeta(m,e,q).

  mulmod_scaled returns mulhi(a,y) - mulhi(d,q) with d = mullo(a,qinvy).
  The first term is bounded using |a*y| <= x*|y|. d is allowed to be any int16.
  """
  xy = abs(x*scaledzeta(m,e,q))
  blow,bhigh = (-xy)//65536,xy//65536 # range of mulhi(a,y)
  dqlow,dqhigh = -32768*q,32767*q # range of d*q for any int16 d
  elow,ehigh = dqlow//65536,dqhigh//65536 # range of mulhi(d,q)
  return max(abs(blow-ehigh),abs(bhigh-elow))
# XXX: test mulmod_scaled_bound
def mulhi(x,y):
  """Simulate a signed 16x16-bit multiply that keeps the high half: floor(x*y/2^16).

  Asserts that both inputs and the result fit in int16.
  """
  for operand in (x,y):
    assert -2**15 <= operand < 2**15
  high = (x*y) >> 16
  assert -2**15 <= high < 2**15
  return high
def mulhrs(x,y):
  """Simulate a rounding high multiply: (x*y + 2^14) >> 15, wrapped to int16.

  The only input pair that wraps is (-2^15,-2^15), which gives -2^15.
  """
  for operand in (x,y):
    assert -2**15 <= operand < 2**15
  rounded = ((x*y >> 14) + 1) >> 1
  wrapped = (rounded + 2**15) % 2**16 - 2**15
  assert -2**15 <= wrapped < 2**15
  return wrapped
def mullo(x,y):
  """Simulate a signed 16x16-bit multiply that keeps the low half (product wrapped to int16)."""
  for operand in (x,y):
    assert -2**15 <= operand < 2**15
  low = (x*y + 2**15) % 2**16 - 2**15
  assert -2**15 <= low < 2**15
  return low
@memoized
def reduce(x,q):
  """Simulate the generated reduce(): return x minus a multiple of q, using int16 operations.

  The quotient estimate comes from one mulhrs, or from mulhi then mulhrs when
  reduce3 is set. mullo then multiplies it by q, and the assert checks the
  product really is a multiple of q.
  """
  if not reduce3:
    estimate = mulhrs(x,qround32(q))
  else:
    estimate = mulhrs(mulhi(x,qrecip(q)),qshift(q))
  multiple = mullo(estimate,q)
  assert multiple%q == 0
  return x-multiple
@memoized
def reduce_bound(q):
  """Return the largest |reduce(x,q)| over every int16 input x, by exhaustive search."""
  bound = 0
  for x in range(-32768,32768):
    bound = max(bound,abs(reduce(x,q)))
  return bound
class modulus:
  r"""
  One factor x^d - root(m)^e of the polynomial being transformed, together with
  the array positions of its coefficients and per-coefficient bounds.
  """
  @staticmethod
  def _progressions(pos,bound=None):
    # Split pos into runs (start,stop,spacing) of constant spacing; with bound given, a run also ends where bound changes.
    progressions = []
    while len(pos) > 0:
      start = pos[0]
      if len(pos) == 1:
        progressions += [(start,start+1,1)]
        break
      spacing = pos[1]-pos[0]
      j = 0
      while j < len(pos):
        if bound is not None and bound[j] != bound[0]: break
        if pos[j]-start != j*spacing: break
        j += 1
      # j >= 2 here: indices 0 and 1 always match in spacing (bound may still stop at 1)
      progressions += [(start,start+j*spacing,spacing)]
      pos = pos[j:]
      if bound is not None:
        bound = bound[j:]
    return progressions
  def setbound(self,bound):
    r"""
    Replace the per-coefficient bounds and recompute boundprogressions:
    runs (start,stop,spacing) of positions that share one bound.
    """
    self.bound = bound
    self.boundprogressions = self._progressions(self.pos,bound)
  def __init__(self,d,m,e,pos,q,bound):
    r"""
    Modulus x^d - root(m)^e,
    coefficient j being at array position pos[j],
    coefficient j modulo q[k] being between +-bound[j][k].
    INPUT:
    "d" - a positive integer
    "m" - a power of 2 (1 is allowed)
    "e" - an integer
    "pos" - a tuple of d distinct integers
    "q" - a tuple of allowed moduli
    "bound" - a tuple of d tuples of nonnegative integers, each at most 32767
    """
    d = int(d)
    assert d >= 1
    m = int(m)
    assert m >= 1
    assert not m&(m-1)
    e = int(e)
    e %= m
    pos = tuple(pos)
    assert len(pos) == d
    assert len(set(pos)) == d
    q = tuple(q)
    bound = tuple(tuple(b) for b in bound)
    assert len(bound) == d
    assert all(len(b) == len(q) for b in bound)
    # normalize: root(m)^e == root(m/2)^(e/2) when e is even
    while m > 1 and e%2 == 0:
      m //= 2
      e //= 2
    # bounds beyond 32767 would mean int16 overflow
    assert all(all(bp >= 0 and bp <= 32767 for bp in b) for b in bound)
    self.degree = d
    self.root = m
    self.rootpow = e
    self.pos = pos
    self.q = q
    self.progressions = self._progressions(pos)
    self.setbound(bound)
  def assertions(self,N,transformpos):
    r"""
    Return codegen operations describing this modulus:
    a blank line, a comment naming the modulus/positions/bounds,
    and an assertranges operation checking the bounds.
    N is not used here.
    """
    m = self.root
    e = self.rootpow
    sign = '-'
    if m >= 2 and 2*e >= m:
      # root(m)^(m/2) = -1, so print x^d-root(m)^e as x^d+root(m)^(e-m/2)
      assert m%2 == 0
      e -= m//2
      sign = '+'
    if e == 0:
      const = '1'
    elif e == 1:
      const = 'zeta%d'%m
    else:
      const = 'zeta%d^%d'%(m,e)
    comment = 'modulus x^%d%s%s' % (self.degree,sign,const)
    comment += ' pos'
    for start,stop,spacing in self.progressions:
      if spacing == 1:
        comment += ' %s:%s' % (start,stop)
      else:
        comment += ' %s:%s:%s' % (start,stop,spacing)
    comment += ' q %s' % ','.join('%d'%q for q in self.q)
    comment += ' bound'
    toprint = tuple(self.bound)
    numprinted = 0
    while len(toprint) > 0:
      # j = length of the run of equal bounds at the front of toprint
      for j in range(len(toprint)+1):
        if j >= len(toprint): break
        if toprint[j] != toprint[0]: break
      numprinted += 1
      if numprinted > 8:
        comment += ' ...'
        break
      comment += ' %s*(%s)' % (j,','.join('%d'%bp for bp in toprint[0]))
      toprint = toprint[j:]
    result = [('blankline',),('comment',comment)]
    result += [('assertranges',transformpos,list(self.bound),list(self.boundprogressions))]
    return result
  def fold(self,other,offset):
    r"""
    Return a modulus
    that merges this modulus with another modulus
    that is identical except for an offset of positions.
    The merged bounds are the coordinatewise maxima; positions are this modulus's.
    INPUT:
    "other": the other modulus
    "offset": a nonzero integer, the distance of the other positions from this one
    """
    assert self.degree == other.degree
    assert self.root == other.root
    assert self.rootpow == other.rootpow
    assert self.q == other.q
    assert other.pos == tuple(j+offset for j in self.pos)
    newbound = [tuple(max(b0i,b1i) for b0i,b1i in zip(b0,b1)) for b0,b1 in zip(self.bound,other.bound)]
    return modulus(self.degree,self.root,self.rootpow,self.pos,self.q,newbound)
  def reduce(self,start,stop,spacing):
    r"""
    Return a tuple of moduli
    produced from this modulus
    by reducing x[j] for each j in range(start,stop,spacing).
    Reduced coefficients get bound reduce_bound(q) for each q.
    INPUT:
    "start": first j position
    "stop": after last j position
    "spacing": spacing between j positions
    """
    J = set(range(start,stop,spacing))
    newbound = []
    d = self.degree
    for i in range(d):
      b = self.bound[i]
      if self.pos[i] in J:
        b = tuple(reduce_bound(q) for q in self.q)
      newbound += [b]
    result = modulus(d,self.root,self.rootpow,self.pos,self.q,newbound)
    return result,
  def twist(self,start,stop,spacing,bm,be):
    r"""
    Return a tuple of moduli
    produced from this modulus
    by multiplying x[j] for each j in range(start,stop,spacing)
    by 1,b,b^2,...
    where b = root(bm)^be.
    The result is the modulus x^d-1 on the same positions.
    INPUT:
    "start": first j position
    "stop": after last j position
    "spacing": spacing between j positions
    "bm": a power of 2 (1 is allowed)
    "be": an integer
    RAISES:
    ValueError if b^d is not root(m)^e for this modulus
    """
    assert bm >= 1
    assert not bm&(bm-1)
    be %= bm
    J = set(range(start,stop,spacing))
    S = set(self.pos)
    if len(J.intersection(S)) == 0:
      # this operation did not affect us
      return self,
    assert tuple(range(start,stop,spacing)) == self.pos
    d = self.degree
    m = self.root
    e = self.rootpow
    # happy to twist if (root(bm)^be)^d == root(m)^e, i.e. be*d/bm == e/m modulo 1
    if (be*d*m-e*bm) % (bm*m) != 0:
      raise ValueError('this modulus is x^%d-root(%d)^%d, not happy with twist by root(%d)^%d' % (d,m,e,bm,be))
    newbound = []
    for i in range(d):
      b = self.bound[i]
      # coefficient i is multiplied by root(bm)^(be*i)
      newb = tuple(mulmod_scaled_bound(bp,bm,be*i,q) for bp,q in zip(b,self.q))
      newbound += [newb]
    result = modulus(d,1,0,self.pos,self.q,newbound)
    return result,
  def butterfly(self,start,stop,spacing,offset,bm,be):
    r"""
    Return a tuple of moduli
    produced from this modulus
    by applying butterflies
    x[j],x[j+offset] = x[j]+x[j+offset]*b,x[j]-x[j+offset]*b
    for each j in range(start,stop,spacing)
    where b = root(bm)^be.
    Result will have one modulus
    if the butterflies don't touch this modulus.
    Result will have two moduli
    if the butterflies reduce this modulus mod x^(d/2)-b and x^(d/2)+b.
    INPUT:
    "start": first j position
    "stop": after last j position
    "spacing": spacing between j positions
    "offset": butterfly distance
    "bm": a power of 2 (1 is allowed)
    "be": an integer
    RAISES:
    ValueError if d is odd or b^2 is not root(m)^e for this modulus
    """
    assert bm >= 1
    assert not bm&(bm-1)
    be %= bm
    J = set(range(start,stop,spacing))
    K = set(j+offset for j in range(start,stop,spacing))
    assert len(J.intersection(K)) == 0
    S = set(self.pos)
    if len(J.union(K).intersection(S)) == 0:
      # this operation did not affect us
      return self,
    # must have all of our coefficients involved
    assert J.union(K).intersection(S) == S
    d = self.degree
    if d%2:
      raise ValueError('butterfly requires even number of coefficients; this modulus has %d' % d)
    m = self.root
    e = self.rootpow
    # must have root(m)^e == root(bm)^(2*be)
    # i.e. root(m*bm)^(e*bm) == root(m*bm)^(2*be*bm)
    if (e*bm-2*be*m) % (m*bm) != 0:
      raise ValueError('this modulus has root(%d)^%d, cannot use butterfly for root(%d)^%d' % (m,e,bm,be))
    for i in range(d//2):
      assert self.pos[i] in J
      assert self.pos[i+d//2] == self.pos[i]+offset
    oldbound0 = self.bound[:d//2]
    oldbound1 = self.bound[d//2:]
    newbound = []
    for i in range(d//2):
      b0 = oldbound0[i]
      b1 = oldbound1[i]
      if be != 0:
        # multiplication by b = 1 is skipped, so only bound the mulmod when be != 0
        b1 = tuple(mulmod_scaled_bound(b1p,bm,be,q) for b1p,q in zip(b1,self.q))
      newbound += [[b0p+b1p for b0p,b1p in zip(b0,b1)]]
    if be == 0:
      # x^(d/2)-1 and x^(d/2)+1 = x^(d/2)-root(2)
      result0 = modulus(d//2,1,0,self.pos[:d//2],self.q,newbound)
      result1 = modulus(d//2,2,1,self.pos[d//2:],self.q,newbound)
    else:
      # -root(bm)^be = root(bm)^(be+bm/2)
      assert bm%2 == 0
      result0 = modulus(d//2,bm,be,self.pos[:d//2],self.q,newbound)
      result1 = modulus(d//2,bm,be+bm//2,self.pos[d//2:],self.q,newbound)
    return result0,result1
# ----- STATE OF THE TRANSFORMATION
# invariants:
# moduli handles "reps" parallel batches; "reps" is run-time variable
# each batch has 2^len(transformpos) transforms
# each transform has N int16 variables
# within each transform, virtual position sum_i 2^i v_i between 0 and N-1
# is stored at physical position sum_i 2^indexing[i] v_i
# within each batch, transform sum_j 2^j t_j between 0 and 2^len(transformpos)-1
# is stored at physical position sum_j 2^transformpos[j] t_j
class state:
def __init__(self,N,qlist):
  """Start a transform of N coefficients (N a power of 2) for each modulus q in qlist.

  Initially there is a single modulus x^N-1 on positions 0..N-1, with every
  coefficient bounded by reduce_bound(q). Virtual bit i sits at physical bit i.
  """
  assert N >= 1
  assert not N&(N-1)
  indexing = list(range(N.bit_length()-1)) # log2(N) identity bit positions
  initialbounds = tuple(reduce_bound(q) for q in qlist)
  # XXX: consider variations allowing larger inputs
  # e.g. current NTT works for initialbounds = tuple(6824 for q in qlist)
  moduli = (modulus(N,1,0,range(N),qlist,[initialbounds]*N),)
  self.N = N
  self.transformpos = []
  self.indexing = indexing
  self.moduli = moduli
  self.initialN = N
  self.qlist = qlist
  self.qdata = {} # resources requested through use()
  self.code = []
  self.output = '' # generated C text
def setbound(self,bound):
for p in self.moduli:
p.setbound(bound)
def use(self,*key):
if key[0][-1] == '*':
howmany = key[-1]
key = key[:-1]
if key not in self.qdata or howmany > self.qdata[key]:
self.qdata[key] = howmany
else:
self.qdata[key] = True
if key == ('function','mulmod_scaled'):
self.use('int16','q')
self.use('function','mulhi')
self.use('function','mullo')
self.use('function','sub')
if key == ('function','reduce'):
if reduce3:
self.use('int16','qrecip')
self.use('int16','qshift')
else:
self.use('int16','qround32')
self.use('int16','q')
self.use('function','mulhi')
self.use('function','mulhrs')
self.use('function','mullo')
self.use('function','sub')
if key == ('function','mulmod_scaled_x16'):
self.use('int16x16','q_x16')
self.use('function','mulhi_x16')
self.use('function','mullo_x16')
self.use('function','sub_x16')
if key == ('function','reduce_x16'):
if reduce3:
self.use('int16x16','qrecip_x16')
self.use('int16x16','qshift_x16')
else:
self.use('int16x16','qround32_x16')
self.use('int16x16','q_x16')
self.use('function','mulhi_x16')
self.use('function','mulhrs_x16')
self.use('function','mullo_x16')
self.use('function','sub_x16')
def printused(self):
  """Prepend the C preamble to self.output.

  The preamble holds includes and macros, one qdata_<q> constant table per q in
  self.qlist, and C definitions of each helper function recorded in self.qdata
  by use(). The #define macros give offsets into the first q's table. Later
  tables follow the same layout and label each entry only with a // comment.
  """
  qdata = self.qdata
  used = ''
  used += '// auto-generated; do not edit\n'
  used += '\n'
  if countops:
    used += '#include "ntt_ops_%d.h"\n' % self.initialN
  else:
    used += '#include "ntt_%d.h"\n' % self.initialN
  used += '\n'
  if generatecassertions:
    used += '#include <assert.h>\n'
    used += '\n'
  if vectorization:
    used += '#include <immintrin.h>\n'
    used += '\n'
    used += '#define _mm256_permute2x128_si256_lo(f0,f1) _mm256_permute2x128_si256(f0,f1,0x20)\n'
    used += '#define _mm256_permute2x128_si256_hi(f0,f1) _mm256_permute2x128_si256(f0,f1,0x31)\n'
    used += '#define int16x16 __m256i\n'
  # NOTE(review): nesting of this block was reconstructed from a whitespace-stripped copy; it is assumed independent of vectorization
  if unrollingpragma:
    used += '\n'
    used += '#define PRAGMA_STRING(s) _Pragma(#s)\n'
    used += '#if defined __clang__\n'
    used += '#define UNROLL(n) PRAGMA_STRING(unroll n)\n'
    used += '#elif defined __GNUC__\n'
    used += '#define UNROLL(n) PRAGMA_STRING(GCC unroll n)\n'
    used += '#else\n'
    used += '#define UNROLL(n)\n'
    used += '#endif\n'
  used += '\n'
  used += 'typedef int16_t int16;\n'
  used += 'typedef int32_t int32;\n'
  if countops:
    used += '\n'
    used += '#include "ntt_ops.h"\n'
    used += '#define nummul_count(n) ntt_ops_mul += (n)\n'
    used += '#define numadd_count(n) ntt_ops_add += (n)\n'
    used += '#define nummul_x16_count(n) ntt_ops_mul_x16 += (n)\n'
    used += '#define numadd_x16_count(n) ntt_ops_add_x16 += (n)\n'
    used += '#define nummulmod_count(n) ntt_ops_mulmod += (n)\n'
    used += '#define numreduce_count(n) ntt_ops_reduce += (n)\n'
  firstq = True # macros are emitted only while writing the first table
  for q in self.qlist:
    assert q > 0
    assert q%2
    assert q < 2**15
    used += '\n'
    if vectorization:
      used += 'static const int16 __attribute((aligned(32))) qdata_%d[] = {\n' % q
    else:
      used += 'static const int16 qdata_%d[] = {\n' % q
    pos = 0 # offset (in int16 entries) of the next resource in the table
    # int16x16 entries are laid out first, then scalars, then arrays.
    # NOTE(review): later int16x16 entries stay 32-byte aligned only if every earlier one has exactly 16 values
    for resourcetype in 'int16x16','int16','int16*':
      for key in sorted(qdata):
        if key[0] != resourcetype: continue
        resource = key[1]
        # macro name: resource followed by its parameters, e.g. scaledzeta_x16_16_3
        macro = resource
        for a in key[2:]:
          macro += '_%s'%a
        if resource == 'q':
          result = q
        elif resource == 'qrecip':
          result = qrecip(q)
        elif resource == 'qshift':
          result = qshift(q)
        elif resource == 'qround32':
          result = qround32(q)
        elif resource == 'scaledzeta':
          bm,be = key[2:]
          result = scaledzeta(bm,be,q)
        elif resource == 'qinvscaledzeta':
          bm,be = key[2:]
          result = qinvscaledzeta(bm,be,q)
        elif resource == 'scaledzeta_pow':
          bm,be = key[2:]
          powers = qdata[key] # array length recorded by use()
          result = scaledzetapow(bm,be,powers,q)
        elif resource == 'qinvscaledzeta_pow':
          bm,be = key[2:]
          powers = qdata[key]
          result = qinvscaledzetapow(bm,be,powers,q)
        elif resource == 'q_x16':
          result = [q]*16
        elif resource == 'qrecip_x16':
          result = [qrecip(q)]*16
        elif resource == 'qshift_x16':
          result = [qshift(q)]*16
        elif resource == 'qround32_x16':
          result = [qround32(q)]*16
        elif resource == 'scaledzeta_x16':
          bm,be = key[2:]
          result = [scaledzeta(bm,be,q)]*16
        elif resource == 'qinvscaledzeta_x16':
          bm,be = key[2:]
          result = [qinvscaledzeta(bm,be,q)]*16
        elif resource.startswith('precomp'):
          # one lane per index j: root(bm)^(be*j), scaled
          bm,be = key[2:4]
          jlist = key[4:]
          result = [scaledzeta(bm,be*j,q) for j in jlist]
        elif resource.startswith('qinvprecomp'):
          bm,be = key[2:4]
          jlist = key[4:]
          result = [qinvscaledzeta(bm,be*j,q) for j in jlist]
        else:
          raise Exception(str((resourcetype,resource)))
        if not firstq:
          used += ' // %s\n' % macro
        else:
          if resourcetype == 'int16':
            used += '#define %s qdata[%d]\n' % (macro,pos)
          elif resourcetype == 'int16*':
            used += '#define %s (qdata+%d)\n' % (macro,pos)
          elif resourcetype == 'int16x16':
            used += '#define %s *(const int16x16 *)(qdata+%d)\n' % (macro,pos)
        if resourcetype in ('int16x16','int16*'):
          used += ' %s,\n' % ','.join('%s'%r for r in result)
          pos += len(result)
        else:
          used += ' %s,\n' % result
          pos += 1
    used += '} ;\n'
    firstq = False
  # C definitions of the scalar helpers, each only if some operation used it
  if ('function','add') in qdata:
    used += """
static int16 add(int16 x,int16 y)
{
"""
    if countops: used += """\
numadd_count(1);
"""
    used += """\
return x+y;
}
"""
  if ('function','sub') in qdata:
    used += """
static int16 sub(int16 x,int16 y)
{
"""
    if countops: used += """\
numadd_count(1);
"""
    used += """\
return x-y;
}
"""
  if ('function','mullo') in qdata:
    used += """
static int16 mullo(int16 x,int16 y)
{
"""
    if countops: used += """\
nummul_count(1);
"""
    used += """\
return x*y;
}
"""
  if ('function','mulhi') in qdata:
    used += """
static int16 mulhi(int16 x,int16 y)
{
"""
    if countops: used += """\
nummul_count(1);
"""
    used += """\
return (x*(int32)y)>>16;
}
"""
  if ('function','mulhrs') in qdata:
    used += """
static int16 mulhrs(int16 x,int16 y)
{
"""
    if countops: used += """\
nummul_count(1);
"""
    used += """\
return (x*(int32)y+16384)>>15;
}
"""
  if ('function','mulmod_scaled') in qdata:
    used += """
static int16 mulmod_scaled(int16 x,int16 y,int16 qinvy,const int16 *qdata)
{
"""
    if countops: used += """\
nummulmod_count(1);
"""
    used += """\
int16 b = mulhi(x,y);
int16 d = mullo(x,qinvy);
int16 e = mulhi(d,q);
return sub(b,e);
}
"""
  if ('function','reduce') in qdata:
    if reduce3:
      used += """
static int16 reduce(int16 x,const int16 *qdata)
{
"""
      if countops: used += """\
numreduce_count(1);
"""
      used += """\
int16 y = mulhi(x,qrecip);
y = mulhrs(y,qshift);
y = mullo(y,q);
return sub(x,y);
}
"""
    else:
      used += """
static int16 reduce(int16 x,const int16 *qdata)
{
"""
      if countops: used += """\
numreduce_count(1);
"""
      used += """\
int16 y = mulhrs(x,qround32);
y = mullo(y,q);
return sub(x,y);
}
"""
  # AVX2 versions, each only if some operation used it
  if ('function','add_x16') in qdata:
    used += """
static inline int16x16 add_x16(int16x16 a,int16x16 b)
{
"""
    if countops: used += """\
numadd_count(16);
numadd_x16_count(1);
"""
    used += """\
return _mm256_add_epi16(a,b);
}
"""
  if ('function','sub_x16') in qdata:
    used += """
static inline int16x16 sub_x16(int16x16 a,int16x16 b)
{
"""
    if countops: used += """\
numadd_count(16);
numadd_x16_count(1);
"""
    used += """\
return _mm256_sub_epi16(a,b);
}
"""
  if ('function','mulmod_scaled_x16') in qdata:
    used += """
static inline int16x16 mulmod_scaled_x16(int16x16 x,int16x16 y,int16x16 yqinv,const int16 *qdata)
{
"""
    if countops: used += """\
nummulmod_count(16);
nummul_count(48);
nummul_x16_count(3);
"""
    used += """\
int16x16 b = _mm256_mulhi_epi16(x,y);
int16x16 d = _mm256_mullo_epi16(x,yqinv);
int16x16 e = _mm256_mulhi_epi16(d,q_x16);
return sub_x16(b,e);
}
"""
  if ('function','reduce_x16') in qdata:
    if reduce3:
      used += """
static inline int16x16 reduce_x16(int16x16 x,const int16 *qdata)
{
"""
      if countops: used += """\
numreduce_count(16);
nummul_count(48);
nummul_x16_count(3);
"""
      used += """\
int16x16 y = _mm256_mulhi_epi16(x,qrecip_x16);
y = _mm256_mulhrs_epi16(y,qshift_x16);
y = _mm256_mullo_epi16(y,q_x16);
return sub_x16(x,y);
}
"""
    else:
      used += """
static inline int16x16 reduce_x16(int16x16 x,const int16 *qdata)
{
"""
      if countops: used += """\
numreduce_count(16);
nummul_count(32);
nummul_x16_count(2);
"""
      used += """\
int16x16 y = _mm256_mulhrs_epi16(x,qround32_x16);
y = _mm256_mullo_epi16(y,q_x16);
return sub_x16(x,y);
}
"""
  self.output = used+self.output
def indent(self):
self.indentlevel += 1
def unindent(self):
assert self.indentlevel > 0
self.indentlevel -= 1
def printline_noindent(self,line):
assert '\n' not in line
self.output += line+'\n'
def printline(self,line):
self.output += ' '*self.indentlevel+line+'\n'
def codegen_blankline(self):
self.output += '\n'
def codegen_comment(self,comment):
self.printline('// %s' % comment)
def codegen_startntt(self,N):
assert not self.infunction
self.printline('')
self.printline('static void ntt%d(int16 *f,long long reps,const int16 *qdata)' % N)
self.printline('{')
self.indent()
self.infunction = True
def codegen_stopntt_inv(self,N):
assert not self.infunction
self.printline('')
self.printline('static void invntt%d(int16 *f,long long reps,const int16 *qdata)' % N)
self.printline('{')
self.indent()
self.infunction = True
assert self.initialN % (self.N<<len(self.transformpos)) == 0
scale = self.initialN//(self.N<<len(self.transformpos))
self.printline('reps *= %d;' % scale)
def codegen_stopntt(self,N):
  """Close ntt<N> and emit, for each q, an exported wrapper that calls it with qdata_<q>.

  Wrappers are named ntt_<N>_<q>, or ntt_ops_<N>_<q> when countops is set.
  """
  assert self.infunction
  self.unindent()
  self.printline('}')
  for q in self.qlist:
    name = ('ntt_ops_%d_%d' if countops else 'ntt_%d_%d') % (N,q)
    self.codegen_blankline()
    self.printline('void %s(int16 *f,long long reps)' % name)
    self.printline('{')
    self.printline(' ntt%d(f,reps,qdata_%d);' % (N,q))
    self.printline('}')
  self.infunction = False
def codegen_startntt_inv(self,N):
  """Inverse-direction counterpart of codegen_startntt: close invntt<N> and emit its wrappers.

  Wrappers are named ntt_<N>_<q>_inv, or ntt_ops_<N>_<q>_inv when countops is set.
  """
  assert self.infunction
  self.unindent()
  self.printline('}')
  for q in self.qlist:
    name = ('ntt_ops_%d_%d_inv' if countops else 'ntt_%d_%d_inv') % (N,q)
    self.codegen_blankline()
    self.printline('void %s(int16 *f,long long reps)' % name)
    self.printline('{')
    self.printline(' invntt%d(f,reps,qdata_%d);' % (N,q))
    self.printline('}')
  self.infunction = False
def codegen_startbatch(self,batchsize):
assert batchsize == self.batchsize
assert self.infunction
assert not self.inbatch
self.printline('for (long long r = 0;r < reps;++r) {')
self.indent()
self.inbatch = True
self.vectornextvar = [0]*((self.initialN+15)//16)
self.vectorvar = [0]*((self.initialN+15)//16)
# 0: latest data in f[16*j],f[16*j+1],...,f[16*j+15]
# 'a': latest data in int16x16 'a%d'%j
# 'b': latest data in int16x16 'b%d'%j
# etc.
# nextvar is 0 if 'a' is next, 1 if 'b' is next, etc.
def codegen_stopbatch(self,batchsize):
assert batchsize == self.batchsize
assert self.infunction
assert self.inbatch
self.vector_spillall()
assert self.vectorvar == [0]*((self.initialN+15)//16)
self.printline('f += %d;' % batchsize)
self.unindent()
self.printline('}')
self.printline('f -= %d*reps;' % batchsize)
self.inbatch = False
self.vectornextvar = None
self.vectorvar = None
codegen_startbatch_inv = codegen_stopbatch
codegen_stopbatch_inv = codegen_startbatch
def codegen_doublereps(self):
assert self.infunction
assert not self.inbatch
assert self.batchsize%2 == 0
self.batchsize //= 2
self.printline('reps *= 2;')
def codegen_doublereps_inv(self):
assert self.infunction
assert not self.inbatch
self.batchsize *= 2
self.printline('reps /= 2;')
def codegen_physical_unmap(self,indexing,transformpos):
self.printline_noindent('#undef F')
def codegen_physical_map(self,indexing,transformpos):
physical = []
physical += [('t',i,transformpos[i]) for i in range(len(transformpos))]
physical += [('v',i,indexing[i]) for i in range(len(indexing))]
physical = '+'.join('((((%s)>>%d)&1)<<%d)' % (source,inbit,outbit) for source,inbit,outbit in physical)
self.printline_noindent('#define F(t,v) f[%s]'%physical)
codegen_physical_unmap_inv = codegen_physical_map
codegen_physical_map_inv = codegen_physical_unmap
def vector_nextvarname(self,j):
v = self.vectornextvar[j]
self.vectornextvar[j] = v+1
if v < 26:
self.vectorvar[j] = 'abcdefghijklmnopqrstuvwxyz'[v]
else:
self.vectorvar[j] = 'L%d_'%v
def vector_ensureloaded(self,j):
if self.vectorvar[j] != 0: return
self.vector_nextvarname(j)
v = self.vectorvar[j]
self.printline('int16x16 %s%d = _mm256_loadu_si256((int16x16 *) (f+%d));' % (v,j,j*16))
def vector_spillall(self):
N = self.initialN
for j in range((N+15)//16):
v = self.vectorvar[j]
if v != 0:
self.printline('_mm256_storeu_si256((int16x16 *) (f+%d),%s%d);' % (j*16,v,j))
self.vectorvar[j] = 0
def selectorder(self,todo):
  """Return the elements (keys, for a dict) of todo sorted by bitreverse, i.e. in bit-reversed order."""
  ordered = list(todo)
  ordered.sort(key=bitreverse)
  return ordered
def codegen_opt_physical_permute(self,perm,oldindexing,oldtransformpos,newindexing,newtransformpos):
  """Return an AVX2 replacement for a physical permutation, or None if there is none.

  Recognized shapes of perm are (3,k), (2,k), (1,2,k) and (0,1,2,k) with k >= 4.
  Each maps to a lo/hi pair of intrinsics applied to vectors 2^k entries
  apart. The result is those vector_permute operations followed by a
  remap of the F(t,v) macro.
  """
  if not vectorization: return None
  patterns = {
    (3,):('_mm256_permute2x128_si256_lo','_mm256_permute2x128_si256_hi'),
    (2,):('_mm256_unpacklo_epi64','_mm256_unpackhi_epi64'),
    (1,2):('_mm256_unpacklo_epi32','_mm256_unpackhi_epi32'),
    (0,1,2):('_mm256_unpacklo_epi16','_mm256_unpackhi_epi16'),
  }
  if len(perm) < 2 or perm[-1] < 4: return None
  prefix = tuple(perm[:-1])
  if prefix not in patterns: return None
  g0,g1 = patterns[prefix]
  offset = perm[-1]
  N = 1<<len(oldindexing)
  assert len(oldindexing) == len(newindexing)
  assert len(oldtransformpos) == len(newtransformpos)
  distance = 1<<offset
  result = [('vector_permute',j,j+distance,g0,g1) for j in range(0,N<<len(oldtransformpos),16) if not j&distance]
  result += [('physical_unmap',oldindexing,oldtransformpos),('physical_map',newindexing,newtransformpos)]
  return result
def codegen_vector_permute(self,j0,j1,g0,g1):
j0 //= 16
j1 //= 16
self.vector_ensureloaded(j0)
self.vector_ensureloaded(j1)
v0 = self.vectorvar[j0]
v1 = self.vectorvar[j1]
self.vector_nextvarname(j0)
self.vector_nextvarname(j1)
w0 = self.vectorvar[j0]
w1 = self.vectorvar[j1]
self.printline('int16x16 %s%d = %s(%s%d,%s%d);' % (w0,j0,g0,v0,j0,v1,j1))
self.printline('int16x16 %s%d = %s(%s%d,%s%d);' % (w1,j1,g1,v0,j0,v1,j1))
def codegen_vector_permute_inv(self,j0,j1,g0,g1):
if (g0,g1) == ('_mm256_permute2x128_si256_lo','_mm256_permute2x128_si256_hi'):
self.codegen_vector_permute(j0,j1,g0,g1)
elif (g0,g1) == ('_mm256_unpacklo_epi64','_mm256_unpackhi_epi64'):
self.codegen_vector_permute(j0,j1,g0,g1)
elif (g0,g1) == ('_mm256_unpacklo_epi32','_mm256_unpackhi_epi32'):
self.codegen_vector_permute(j0,j1,g0,g1)
self.codegen_vector_permute(j0,j1,g0,g1)
elif (g0,g1) == ('_mm256_unpacklo_epi16','_mm256_unpackhi_epi16'):
self.codegen_vector_permute(j0,j1,g0,g1)
self.codegen_vector_permute(j0,j1,g0,g1)
self.codegen_vector_permute(j0,j1,g0,g1)
else:
raise ValueError('unknown permutation %s %s' % (g0,g1))
def codegen_physical_permute(self,perm,oldindexing,oldtransformpos,newindexing,newtransformpos):
  """Emit scalar C that copies the batch into a temporary array and writes it back under the new layout.

  perm itself is not used here; the old and new layouts fully determine the copy.
  """
  N = 1<<len(oldindexing)
  assert len(oldindexing) == len(newindexing)
  assert len(oldtransformpos) == len(newtransformpos)
  if vectorization: self.codegen_comment('nttcompiler did not vectorize this')
  self.printline('{')
  self.printline('int16 rearrange[%d];' % (N<<len(oldtransformpos)))
  def copyloops(transformpos,body):
    # nested loops over transforms t and positions j, with optional unrolling hints
    if unrollingpragma: self.printline_noindent('UNROLL(%d)' % min(16,1<<len(transformpos)))
    self.printline('for (long long t = 0;t < %d;++t)' % (1<<len(transformpos)))
    if unrollingpragma: self.printline_noindent('UNROLL(%d)' % min(16,N))
    self.printline(' for (long long j = 0;j < %d;++j) %s' % (N,body))
  copyloops(oldtransformpos,'rearrange[t*%d+j] = F(t,j);' % N)
  self.codegen_physical_unmap(oldindexing,oldtransformpos)
  self.codegen_physical_map(newindexing,newtransformpos)
  copyloops(newtransformpos,'F(t,j) = rearrange[t*%d+j];' % N)
  self.printline('}')
def codegen_physical_permute_inv(self,perm,oldindexing,oldtransformpos,newindexing,newtransformpos):
self.codegen_physical_permute(reversed(perm),newindexing,newtransformpos,oldindexing,oldtransformpos)
def codegen_assertranges(self,transformpos,bound,boundprogressions):
  """When generatecassertions is set, emit C assert() checks that coefficients stay within +-bound.

  bound[i][k] is the bound for the i-th position listed by boundprogressions,
  modulo qlist[k]. boundprogressions holds runs (start,stop,spacing) of
  positions sharing one bound. The emitted code compares the q macro
  against each qlist entry to pick the bounds.
  """
  if not generatecassertions: return
  # the emitted test reads the q macro, which printused only #defines if ('int16','q') is registered
  self.use('int16','q')
  self.vector_spillall()
  self.printline('for (long long t = 0;t < %d;++t) {' % (1<<len(transformpos)))
  for k in range(len(self.qlist)):
    self.printline(' if (q == %d) {' % self.qlist[k])
    boundpos = 0
    for start,stop,spacing in boundprogressions:
      boundmove = (stop-start)//spacing
      assert boundmove > 0
      if boundmove == 1:
        self.printline(' assert(F(t,%d) >= -%d && F(t,%d) <= %d);' % (start,bound[boundpos][k],start,bound[boundpos][k]))
      else:
        # one loop per run; every position in the run must share the same bound
        assert all(bound[boundpos+j][k] == bound[boundpos][k] for j in range(boundmove))
        self.printline(' for (long long j = %d;j != %d;j += %d)' % (start,stop,spacing))
        self.printline(' assert(F(t,j) >= -%d && F(t,j) <= %d);' % (bound[boundpos][k],bound[boundpos][k]))
      boundpos += boundmove
    self.printline(' }')
  self.printline('}')
def codegen_assertranges_inv(self,transformpos,bound,boundprogressions):
pass # XXX
def codegen_opt_reduce(self,start,stop,spacing,indexing,transformpos):
vectorpos = self.vectorpositions(start,stop,spacing,indexing,transformpos)
if not vectorpos: return
return [('vector_reduce',v) for v in self.selectorder(vectorpos)]
def codegen_opt_reduce_ifforward(self,start,stop,spacing,indexing,transformpos):
vectorpos = self.vectorpositions(start,stop,spacing,indexing,transformpos)
if not vectorpos: return
return [('vector_reduce_ifforward',v) for v in self.selectorder(vectorpos)]
def codegen_opt_reduce_ifreverse(self,start,stop,spacing,indexing,transformpos):
vectorpos = self.vectorpositions(start,stop,spacing,indexing,transformpos)
if not vectorpos: return
return [('vector_reduce_ifreverse',v) for v in self.selectorder(vectorpos)]
def codegen_vector_reduce(self,v):
self.use('function','reduce_x16')
j = v//16
self.vector_ensureloaded(j)
v = self.vectorvar[j]
self.printline('%s%d = reduce_x16(%s%d,qdata);' % (v,j,v,j))
codegen_vector_reduce_inv = codegen_vector_reduce
def codegen_reduce(self,start,stop,spacing,indexing,transformpos):
  """Emit scalar C applying reduce() to F(t,j) for j in range(start,stop,spacing), in every transform t."""
  self.vector_spillall()
  if vectorization:
    self.codegen_comment('nttcompiler did not vectorize this')
  self.use('function','reduce')
  numtransforms = 1<<len(transformpos)
  if unrollingpragma:
    self.printline_noindent('UNROLL(%d)' % min(16,numtransforms))
  self.printline('for (long long t = 0;t < %d;++t) {' % numtransforms)
  if unrollingpragma:
    self.printline_noindent('UNROLL(%d)' % min(16,(stop-start)//spacing))
  self.printline(' for (long long j = %d;j != %d;j += %d) {' % (start,stop,spacing))
  for text in (' int16 f0 = F(t,j);',' f0 = reduce(f0,qdata);',' F(t,j) = f0;',' }','}'):
    self.printline(text)
codegen_reduce_inv = codegen_reduce
def codegen_skip(self,*args):
pass
codegen_reduce_ifforward = codegen_reduce
codegen_reduce_ifforward_inv = codegen_skip
codegen_vector_reduce_ifforward = codegen_vector_reduce
codegen_vector_reduce_ifforward_inv = codegen_skip
codegen_reduce_ifreverse = codegen_skip
codegen_reduce_ifreverse_inv = codegen_reduce
codegen_vector_reduce_ifreverse = codegen_skip
codegen_vector_reduce_ifreverse_inv = codegen_vector_reduce
def codegen_opt_twist(self,start,stop,spacing,bm,be,indexing,transformpos):
vectorpos = self.vectorpositions(start,stop,spacing,indexing,transformpos)
if not vectorpos: return
result = []
for v in self.selectorder(vectorpos):
result += [('vector_twist',v,bm,be,*vectorpos[v])]
return result
def codegen_vector_twist(self,v,bm,be,*indices):
be %= bm
self.use('function','mulmod_scaled_x16')
self.use('int16x16','precomp',bm,be,*indices)
self.use('int16x16','qinvprecomp',bm,be,*indices)
precompname = 'precomp_%d_%d'%(bm,be)
for e in indices:
precompname += '_%s'%e
j = v//16
self.vector_ensureloaded(j)
v = self.vectorvar[j]
self.printline('%s%d = mulmod_scaled_x16(%s%d,%s,qinv%s,qdata);' % (v,j,v,j,precompname,precompname))
def codegen_vector_twist_inv(self,v,bm,be,*indices):
self.codegen_vector_twist(v,bm,-be,*indices)
def codegen_twist(self,start,stop,spacing,bm,be,indexing,transformpos):
  """Emit scalar C multiplying F(t,start+spacing*j) by root(bm)^(be*j), for each j and each transform t.

  The multipliers come from the scaledzeta_pow/qinvscaledzeta_pow tables registered with use().
  """
  self.vector_spillall()
  if vectorization:
    self.codegen_comment('nttcompiler did not vectorize this')
  be = be % bm
  count = (stop-start)//spacing
  self.use('function','mulmod_scaled')
  for table in ('scaledzeta_pow','qinvscaledzeta_pow'):
    self.use('int16*',table,bm,be,count)
  numtransforms = 1<<len(transformpos)
  if unrollingpragma: self.printline_noindent('UNROLL(%d)' % min(16,numtransforms))
  self.printline('for (long long t = 0;t < %d;++t) {' % numtransforms)
  if unrollingpragma: self.printline_noindent('UNROLL(%d)' % min(16,count))
  self.printline(' for (long long j = %d;j != %d;++j) {' % (0,count))
  self.printline(' int16 f0 = F(t,%d%+d*j);'% (start,spacing))
  self.printline(' f0 = mulmod_scaled(f0,scaledzeta_pow_%d_%d[j],qinvscaledzeta_pow_%d_%d[j],qdata);'% (bm,be,bm,be))
  self.printline(' F(t,%d%+d*j) = f0;'% (start,spacing))
  self.printline(' }')
  self.printline('}')
def codegen_twist_inv(self,start,stop,spacing,bm,be,indexing,transformpos):
    """Inverse of codegen_twist: twist the same positions by exponent -be."""
    negated = -be
    self.codegen_twist(start,stop,spacing,bm,negated,indexing,transformpos)
def codegen_opt_butterfly(self,start,stop,spacing,offset,bm,be,indexing,transformpos):
    """Try to rewrite a scalar 'butterfly' instruction as vector butterflies.

    Each butterfly pairs physical positions F(t,j) and F(t,j+offset) for
    every transform t of the batch and every j in range(start,stop,spacing).
    The rewrite succeeds only if these pairs line up whole 16-aligned
    vectors: each lane i of a vector v must be paired, on the same side,
    with lane i of one fixed partner vector.  In that case one
    ('vector_butterfly', v0, v1, bm, be) tuple is returned per side-0
    vector v0, in selectorder() order.  Otherwise, or when vectorization is
    off, returns None so codegen() keeps the scalar instruction.
    """
    if not vectorization: return
    # Physical position of virtual entry j of transform t: bit i of t goes to
    # physical bit transformpos[i], bit i of j goes to physical bit indexing[i].
    def F(t,j):
        assert t >= 0
        assert t < 1<<len(transformpos)
        assert j >= 0
        assert j < 1<<len(indexing)
        result = 0
        result += sum(((t>>i)&1)<<transformpos[i] for i in range(len(transformpos)))
        result += sum(((j>>i)&1)<<indexing[i] for i in range(len(indexing)))
        return result
    # todo: physical position -> (side, partner position); side 0 is the j
    # slot, side 1 the j+offset slot.
    todo = {}
    # todovector: 16-aligned vector base -> (side, partner vector base).
    todovector = {}
    for t in range(1<<len(transformpos)):
        for j in range(start,stop,spacing):
            physical0 = F(t,j)
            physical1 = F(t,j+offset)
            assert physical0 not in todo
            assert physical1 not in todo
            todo[physical0] = (0,physical1)
            todo[physical1] = (1,physical0)
            # base of the 16-lane vector containing each position
            physical0vector = physical0&~15
            physical1vector = physical1&~15
            # self.codegen_comment('physical butterfly %d,%d vectors %d,%d' % (physical0,physical1,physical0vector,physical1vector))
            if physical0vector not in todovector:
                todovector[physical0vector] = (0,physical1vector)
            if physical1vector not in todovector:
                todovector[physical1vector] = (1,physical0vector)
            # A vector must always play the same side against the same partner
            # (this also rejects pairs within one vector); else stay scalar.
            if todovector[physical0vector] != (0,physical1vector): return
            if todovector[physical1vector] != (1,physical0vector): return
    # Every lane of every touched vector must be a butterfly with the same
    # lane of its partner vector.
    for v in todovector:
        side,vx = todovector[v]
        for i in range(16):
            if v+i not in todo or todo[v+i] != (side,vx+i):
                return
    result = []
    for v0 in self.selectorder(todovector):
        side,v1 = todovector[v0]
        if side != 0: continue
        result += [('vector_butterfly',v0,v1,bm,be)]
    return result
def codegen_vector_butterfly(self,v0,v1,bm,be):
    """Emit one vectorized forward butterfly on two 16-lane vectors.

    v0 and v1 are physical positions; the vectors holding them are v0//16
    and v1//16.  With be reduced mod bm, the second vector is first
    multiplied in place via mulmod_scaled_x16 by scaledzeta_x16_<bm>_<be>
    (skipped when be == 0).  Both vectors then move to fresh variable names
    holding add_x16(a,b) and sub_x16(a,b).
    """
    self.use('function','add_x16')
    self.use('function','sub_x16')
    be %= bm
    twisted = be != 0
    if twisted:
        self.use('function','mulmod_scaled_x16')
        self.use('int16x16','scaledzeta_x16',bm,be)
        self.use('int16x16','qinvscaledzeta_x16',bm,be)
    lo, hi = v0//16, v1//16
    self.vector_ensureloaded(lo)
    a = self.vectorvar[lo]
    self.vector_ensureloaded(hi)
    b = self.vectorvar[hi]
    if twisted:
        self.printline('%s%d = mulmod_scaled_x16(%s%d,scaledzeta_x16_%d_%d,qinvscaledzeta_x16_%d_%d,qdata);' % (b,hi,b,hi,bm,be,bm,be))
    self.vector_nextvarname(lo)
    self.vector_nextvarname(hi)
    suma, diffb = self.vectorvar[lo], self.vectorvar[hi]
    self.printline('int16x16 %s%d = add_x16(%s%d,%s%d);' % (suma,lo,a,lo,b,hi))
    self.printline('int16x16 %s%d = sub_x16(%s%d,%s%d);' % (diffb,hi,a,lo,b,hi))
def codegen_vector_butterfly_inv(self,v0,v1,bm,be):
    """Emit one vectorized inverse butterfly on two 16-lane vectors.

    v0 and v1 are physical positions; the vectors holding them are v0//16
    and v1//16.  Both vectors move to fresh variable names holding
    add_x16(a,b) and sub_x16(a,b).  Unless (-be) mod bm is 0, the
    difference is then multiplied in place via mulmod_scaled_x16 by
    scaledzeta_x16_<bm>_<(-be) mod bm>.
    """
    self.use('function','add_x16')
    self.use('function','sub_x16')
    be = (-be)%bm
    twisted = be != 0
    if twisted:
        self.use('function','mulmod_scaled_x16')
        self.use('int16x16','scaledzeta_x16',bm,be)
        self.use('int16x16','qinvscaledzeta_x16',bm,be)
    lo, hi = v0//16, v1//16
    self.vector_ensureloaded(lo)
    a = self.vectorvar[lo]
    self.vector_ensureloaded(hi)
    b = self.vectorvar[hi]
    self.vector_nextvarname(lo)
    self.vector_nextvarname(hi)
    suma, diffb = self.vectorvar[lo], self.vectorvar[hi]
    self.printline('int16x16 %s%d = add_x16(%s%d,%s%d);' % (suma,lo,a,lo,b,hi))
    self.printline('int16x16 %s%d = sub_x16(%s%d,%s%d);' % (diffb,hi,a,lo,b,hi))
    if twisted:
        self.printline('%s%d = mulmod_scaled_x16(%s%d,scaledzeta_x16_%d_%d,qinvscaledzeta_x16_%d_%d,qdata);' % (diffb,hi,diffb,hi,bm,be,bm,be))
def codegen_butterfly(self,start,stop,spacing,offset,bm,be,indexing,transformpos):
    """Emit scalar C code for a layer of forward butterflies not vectorized.

    For each of the 2**len(transformpos) transforms t and each
    j = start, start+spacing, ... != stop, the pair (F(t,j), F(t,j+offset))
    becomes (add(f0,f1), sub(f0,f1)).  f1 is first multiplied via
    mulmod_scaled by scaledzeta_<bm>_<be>, unless be == 0.  be is used as
    given; the butterfly() method reduces it mod bm before recording.
    """
    self.vector_spillall()
    if vectorization:
        self.codegen_comment('nttcompiler did not vectorize this')
    self.use('function','add')
    self.use('function','sub')
    twisted = be != 0
    if twisted:
        self.use('int16','scaledzeta',bm,be)
        self.use('int16','qinvscaledzeta',bm,be)
    batchcount = 1<<len(transformpos)
    if unrollingpragma:
        self.printline_noindent('UNROLL(%d)' % min(16,batchcount))
    self.printline('for (long long t = 0;t < %d;++t) {' % batchcount)
    if unrollingpragma:
        self.printline_noindent('UNROLL(%d)' % min(16,(stop-start)//spacing))
    self.printline(' for (long long j = %d;j != %d;j += %d) {' % (start,stop,spacing))
    self.printline(' int16 f0 = F(t,j);')
    self.printline(' int16 f1 = F(t,j%+d);' % offset)
    if twisted:
        self.printline(' f1 = mulmod_scaled(f1,scaledzeta_%d_%d,qinvscaledzeta_%d_%d,qdata);' % (bm,be,bm,be))
    for text in (' F(t,j) = add(f0,f1);',' F(t,j%+d) = sub(f0,f1);' % offset,' }','}'):
        self.printline(text)
def codegen_butterfly_inv(self,start,stop,spacing,offset,bm,be,indexing,transformpos):
    """Emit scalar C code for a layer of inverse butterflies.

    For each transform t and each j = start, start+spacing, ... != stop,
    the pair (F(t,j), F(t,j+offset)) becomes (add(f0,f1), sub(f0,f1)).
    The difference is then multiplied via mulmod_scaled by
    scaledzeta_<bm>_<(-be) mod bm>, unless that exponent is 0.
    """
    self.vector_spillall()
    if vectorization:
        self.codegen_comment('nttcompiler did not vectorize this')
    be = (-be)%bm
    self.use('function','add')
    self.use('function','sub')
    twisted = be != 0
    if twisted:
        self.use('int16','scaledzeta',bm,be)
        self.use('int16','qinvscaledzeta',bm,be)
    batchcount = 1<<len(transformpos)
    if unrollingpragma:
        self.printline_noindent('UNROLL(%d)' % min(16,batchcount))
    self.printline('for (long long t = 0;t < %d;++t) {' % batchcount)
    if unrollingpragma:
        self.printline_noindent('UNROLL(%d)' % min(16,(stop-start)//spacing))
    self.printline(' for (long long j = %d;j != %d;j += %d) {' % (start,stop,spacing))
    for text in (' int16 f0 = F(t,j);',' int16 f1 = F(t,j%+d);' % offset,' F(t,j) = add(f0,f1);',' f1 = sub(f0,f1);'):
        self.printline(text)
    if twisted:
        self.printline(' f1 = mulmod_scaled(f1,scaledzeta_%d_%d,qinvscaledzeta_%d_%d,qdata);' % (bm,be,bm,be))
    for text in (' F(t,j%+d) = f1;' % offset,' }','}'):
        self.printline(text)
def codegen(self):
    """Turn the recorded instruction list self.code into output code.

    Steps, in order:
      1. pass 1: echo every recorded instruction as a comment;
      2. replace each instruction by the list returned from
         codegen_opt_<op>, when that method is defined on the class itself
         and returns something non-empty.  Comments and blanklines are
         dropped, and so is assertranges unless generatecassertions is set;
      3. if only ntt/batch bookkeeping, physical_map/physical_unmap and
         vector_* instructions remain, drop the map/unmap instructions;
      4. move stopbatch/startbatch/doublereps ahead of any immediately
         preceding physical_map/physical_unmap;
      5. split each run of vector instructions into at most 64 "big lane"
         groups separated by stopbatch/startbatch;
      6. peephole: unmap,map,unmap -> unmap, and drop a startbatch that is
         immediately followed by stopbatch;
      7. pass 2: echo the final list, then emit the forward code by calling
         codegen_<op> on each instruction, then the inverse code by calling
         codegen_<op>_inv on each instruction in reverse order.
    Finally asserts that indentation, batch and function state are closed.
    """
    N = self.initialN
    self.indentlevel = 0
    self.batchsize = N
    self.infunction = False
    self.inbatch = False
    self.codegen_blankline()
    self.codegen_comment('----- codegen pass 1')
    self.codegen_comment('')
    for c in self.code:
        if c[0] == 'comment':
            # comment instructions carry exactly one payload element
            self.codegen_comment('// %s' % str(*c[1:]))
        elif c[0] == 'blankline':
            self.codegen_comment('')
        elif c[0] == 'assertranges':
            self.codegen_comment('assertranges ...')
        else:
            self.codegen_comment(' '.join(str(ci) for ci in c))
    # ----- initial optimization of each instruction
    newcode = []
    for c in self.code:
        if c[0] in ('comment','blankline'): continue
        if c[0] == 'assertranges':
            if not generatecassertions: continue
        opt = 'codegen_opt_%s'%c[0]
        # only optimizers defined directly on this class are consulted
        if opt in type(self).__dict__:
            opt = self.__getattribute__(opt)
            opt = opt(*c[1:])
            if opt:
                newcode += opt
                continue
        newcode += [c]
    self.code = newcode
    # ----- eliminate maps if purely vectorized
    purelyvectorized = True
    for c in self.code:
        if c[0] in ('startntt','stopntt','startbatch','stopbatch','doublereps'): continue
        if c[0] in ('physical_unmap','physical_map'): continue
        if c[0].startswith('vector_'): continue
        purelyvectorized = False
        break
    if purelyvectorized:
        newcode = []
        for c in self.code:
            if c[0] in ('physical_unmap','physical_map'): continue
            newcode += [c]
        self.code = newcode
    # ----- reshuffle
    vectorops = ('vector_butterfly','vector_reduce','vector_reduce_ifforward','vector_reduce_ifreverse','vector_twist','vector_permute')
    # insertion-sort style pass: bubble batch-boundary instructions earlier
    # past any physical_map/physical_unmap directly before them
    for i in range(1,len(self.code)):
        j = i
        while j > 0:
            wantswap = False
            if self.code[j-1][0] in ('physical_unmap','physical_map'):
                if self.code[j][0] in ('stopbatch','startbatch','doublereps'):
                    wantswap = True
            if 0: # doesn't seem as good as the more ad-hoc partitioning below
                if self.code[j-1][0] in vectorops:
                    if self.code[j][0] in vectorops:
                        inputs0 = self.code[j-1][1:3] if self.code[j-1][0] in ('vector_butterfly','vector_permute') else self.code[j-1][1:2]
                        inputs1 = self.code[j][1:3] if self.code[j][0] in ('vector_butterfly','vector_permute') else self.code[j][1:2]
                        if len(set(inputs0).intersection(set(inputs1))) == 0:
                            if bitreverse(inputs1[-1]) < bitreverse(inputs0[-1]):
                                wantswap = True
            if not wantswap: break
            self.code[j-1],self.code[j] = self.code[j],self.code[j-1]
            j -= 1
    # ----- try to partition vector instructions into big lanes
    code = self.code
    newcode = []
    batchsize = self.initialN
    while len(code) > 0:
        # j becomes the index of the first non-vector instruction.
        # NOTE(review): if every remaining instruction is a vector op, the
        # loop ends without break and j == len(code)-1, so the last op is
        # split off and later copied as-is; presumably harmless because
        # stopntt() appends 'stopntt' last.
        for j in range(len(code)):
            if code[j][0] not in vectorops:
                break
        if j == 0:
            # non-vector instruction: copy through, tracking batch halving
            if code[0][0] == 'doublereps':
                batchsize //= 2
            newcode += code[:1]
            code = code[1:]
        else:
            vectorcode,code = code[:j],code[j:]
            # Find the largest power of two okbiglane (<= initialN) such
            # that, within each instruction, all vector inputs agree mod it.
            okbiglane = 1
            biglane = 1
            while True:
                ok = True
                for c in vectorcode:
                    inputs = c[1:3] if c[0] in ('vector_butterfly','vector_permute') else c[1:2]
                    if len(set(i%biglane for i in inputs)) > 1:
                        ok = False
                        break
                if not ok: break
                okbiglane = biglane
                biglane *= 2
                if biglane > self.initialN: break
            if okbiglane > 64: okbiglane = 1
            # Emit the run grouped by c[1] mod okbiglane, closing and
            # reopening the batch between groups.
            for r in range(okbiglane):
                if r != 0:
                    newcode += [('stopbatch',batchsize),('startbatch',batchsize)]
                for c in vectorcode:
                    if c[1]%okbiglane == r:
                        newcode += [c]
    self.code = newcode
    # ----- peephole
    progress = True
    while progress:
        progress = False
        # unmap, map, unmap -> unmap (drop the map/unmap pair)
        for i in range(2,len(self.code)):
            if self.code[i-2][0] == 'physical_unmap':
                if self.code[i-1][0] == 'physical_map':
                    if self.code[i][0] == 'physical_unmap':
                        self.code = self.code[:i-1]+self.code[i+1:]
                        progress = True
                        break
        # drop empty batches: startbatch immediately followed by stopbatch
        for i in range(1,len(self.code)):
            if self.code[i-1][0] == 'startbatch':
                if self.code[i][0] == 'stopbatch':
                    self.code = self.code[:i-1]+self.code[i+1:]
                    progress = True
                    break
    # ----- done with pass 2
    self.codegen_blankline()
    self.codegen_comment('----- codegen pass 2')
    self.codegen_comment('')
    for c in self.code:
        if c[0] == 'assertranges':
            self.codegen_comment('assertranges ...')
        else:
            self.codegen_comment(' '.join(str(ci) for ci in c))
    # forward code: dispatch each instruction to codegen_<op>
    for c in self.code:
        if c[0] == 'assertranges':
            self.codegen_comment('assertranges ...')
        else:
            self.codegen_comment(' '.join(str(a) for a in c))
        self.__getattribute__('codegen_'+c[0])(*c[1:])
    # XXX: track ranges through inverse
    # inverse code: dispatch each instruction, last first, to codegen_<op>_inv
    for c in reversed(self.code):
        if c[0] == 'assertranges':
            self.codegen_comment('assertranges ...')
        else:
            self.codegen_comment('inv '+' '.join(str(a) for a in c))
        self.__getattribute__('codegen_'+c[0]+'_inv')(*c[1:])
    assert self.indentlevel == 0
    assert not self.inbatch
    assert not self.infunction
def physical_permute(self,*perm):
    r"""
    Reshuffle data.
    Data at physical position sum_i 2^i p_i
    moves to physical position sum_i 2^NEW(i) p_i
    where NEW maps perm[0] to perm[1],
    maps perm[1] to perm[2],
    etc.,
    maps perm[-1] to perm[0],
    and fixes any other inputs.

    Every entry of perm must be a bit position in use (in indexing or
    transformpos).  Records ('physical_permute', perm, old indexing,
    old transformpos, new indexing, new transformpos) and switches
    self.indexing / self.transformpos to the permuted lists.
    """
    indexing = self.indexing
    transformpos = self.transformpos
    # identity on every physical bit position currently in use
    NEW = {pos: pos for pos in indexing}
    NEW.update((pos,pos) for pos in transformpos)
    for p in perm:
        assert p in NEW
    # rotate the listed positions one step: perm[k] -> perm[k+1], last -> first
    NEW.update(zip(perm,perm[1:]+perm[:1]))
    assert set(indexing+transformpos) == {NEW[i] for i in indexing} | {NEW[i] for i in transformpos}
    # virtual bit i, stored at physical bit indexing[i], moves to NEW[indexing[i]]
    newindexing = [NEW[pos] for pos in indexing]
    newtransformpos = [NEW[pos] for pos in transformpos]
    self.code += [('physical_permute',perm,tuple(indexing),tuple(transformpos),tuple(newindexing),tuple(newtransformpos))]
    self.indexing = newindexing
    self.transformpos = newtransformpos
def assertions(self):
    """Record layout comments and, optionally, per-modulus range assertions.

    Appends comments giving the transform size, indexing, transforms per
    batch, batch indexing and total batch size.  When generatecassertions
    is set, it also requests the int16 constant q and appends each
    modulus's p.assertions(N,transformpos) instructions.
    """
    N = self.N
    indexing = self.indexing
    transformpos = self.transformpos
    moduli = self.moduli
    assert 1<<len(indexing) == N
    notes = ('transform size %d' % N,
             'transform indexing %s' % indexing,
             'transforms per batch %d' % (1<<len(transformpos)),
             'batch indexing %s' % transformpos,
             'total batch size %d' % (N<<len(transformpos)))
    self.code += [('comment',note) for note in notes]
    if not generatecassertions:
        return
    self.use('int16','q')
    for p in moduli:
        self.code += p.assertions(N,transformpos)
def startbatch(self):
    """Record a 'startbatch' tagged with the batch size N << len(transformpos)."""
    size = self.N<<len(self.transformpos)
    self.code += [('startbatch',size)]
def stopbatch(self):
    """Record a 'stopbatch' tagged with the batch size N << len(transformpos)."""
    size = self.N<<len(self.transformpos)
    self.code += [('stopbatch',size)]
def startntt(self):
    """Open the transform: record 'startntt', open a batch, emit the
    precondition comment and the initial physical_map, then the layout
    assertions."""
    self.code += [('startntt',self.initialN)]
    self.startbatch()
    layout = (tuple(self.indexing),tuple(self.transformpos))
    self.code += [('comment','----- PRECONDITIONS'),('physical_map',)+layout]
    self.assertions()
def startlayer(self,layernum):
    """Record a blank line and a '----- LAYER <n>' comment."""
    header = '----- LAYER %d' % layernum
    self.code += [('blankline',),('comment',header)]
def stoplayer(self,layernum):
    """Record a blank line and a postconditions comment, then the layout
    assertions for the state after layer <n>."""
    header = '----- POSTCONDITIONS AFTER LAYER %d' % layernum
    self.code += [('blankline',),('comment',header)]
    self.assertions()
def announce(self,functionname,args):
    """Record a blank line and a 'name(arg,arg,...)' comment for a strategy step."""
    call = '%s(%s)' % (functionname,','.join(map(str,args)))
    self.code += [('blankline',),('comment',call)]
def stopntt(self):
    """Close the transform: close the batch, record the final physical_unmap,
    then 'stopntt'."""
    self.stopbatch()
    layout = (tuple(self.indexing),tuple(self.transformpos))
    self.code += [('physical_unmap',)+layout,('stopntt',self.initialN)]
def vectorpositions(self,start,stop,spacing,indexing,transformpos):
    """Map the positions start, start+spacing, ... (before stop) onto 16-lane vectors.

    Returns False if vectorization is off, if (stop-start) is not a multiple
    of spacing, or if the touched physical positions (over every transform
    in the batch) do not form whole 16-aligned blocks.  Otherwise returns
    {v: [e_0, ..., e_15]}, where v is the base of an aligned block and e_i
    is the step number along the range stored in lane i.  An empty range
    gives {}.
    """
    if not vectorization: return False
    if (stop-start)%spacing: return False
    def physical(t,j):
        # bit i of t -> physical bit transformpos[i]; bit i of j -> indexing[i]
        assert 0 <= t < 1<<len(transformpos)
        assert 0 <= j < 1<<len(indexing)
        pos = 0
        for bit,where in enumerate(transformpos):
            pos += ((t>>bit)&1)<<where
        for bit,where in enumerate(indexing):
            pos += ((j>>bit)&1)<<where
        return pos
    step_at = {}
    for t in range(1<<len(transformpos)):
        for e in range((stop-start)//spacing):
            pos = physical(t,start+e*spacing)
            assert pos not in step_at
            step_at[pos] = e
    keys = sorted(step_at)
    if len(keys)%16: return False
    blocks = {}
    for k in range(0,len(keys),16):
        base = keys[k]
        if base%16 or keys[k:k+16] != list(range(base,base+16)):
            return False
        assert base not in blocks
        blocks[base] = [step_at[base+lane] for lane in range(16)]
    return blocks
def reduce(self,start,stop,spacing):
    """Record a 'reduce' of positions start, start+spacing, ... (before stop).

    Asserts the range is a whole number of steps inside 0..N-1, appends
    ('reduce', start, stop, spacing, indexing, transformpos) to self.code,
    and replaces self.moduli by the concatenation of every modulus's
    p.reduce(start,stop,spacing) tuple.
    """
    N, moduli = self.N, self.moduli
    layout = (tuple(self.indexing),tuple(self.transformpos))
    assert (stop-start)%spacing == 0
    touched = range(start,stop,spacing)
    assert all(j >= 0 for j in touched)
    assert all(j < N for j in touched)
    self.code += [('reduce',start,stop,spacing)+layout]
    self.moduli = tuple(r for p in moduli for r in p.reduce(start,stop,spacing))
def reduce_ifforward(self,start,stop,spacing):
    """Like reduce(), but recorded as 'reduce_ifforward'.

    The instruction is tagged so codegen can treat forward and inverse
    passes differently.  self.moduli is updated with each modulus's
    p.reduce(start,stop,spacing), exactly as reduce() does.
    """
    N, moduli = self.N, self.moduli
    layout = (tuple(self.indexing),tuple(self.transformpos))
    assert (stop-start)%spacing == 0
    touched = range(start,stop,spacing)
    assert all(j >= 0 for j in touched)
    assert all(j < N for j in touched)
    self.code += [('reduce_ifforward',start,stop,spacing)+layout]
    self.moduli = tuple(r for p in moduli for r in p.reduce(start,stop,spacing))
def reduce_ifreverse(self,start,stop,spacing):
    """Record a 'reduce_ifreverse' of positions start, start+spacing, ... (before stop).

    Performs the same range checks as reduce() but leaves self.moduli
    unchanged.
    """
    N = self.N
    layout = (tuple(self.indexing),tuple(self.transformpos))
    assert (stop-start)%spacing == 0
    touched = range(start,stop,spacing)
    assert all(j >= 0 for j in touched)
    assert all(j < N for j in touched)
    self.code += [('reduce_ifreverse',start,stop,spacing)+layout]
def twist(self,start,stop,spacing,bm,be):
    """Record a 'twist' of positions start, start+spacing, ... (before stop).

    be is reduced mod bm before it is recorded and passed on.  After the
    range checks, appends ('twist', start, stop, spacing, bm, be, indexing,
    transformpos) and replaces self.moduli by the concatenation of every
    modulus's p.twist(start,stop,spacing,bm,be) tuple.
    """
    N, moduli = self.N, self.moduli
    layout = (tuple(self.indexing),tuple(self.transformpos))
    assert (stop-start)%spacing == 0
    be %= bm
    touched = range(start,stop,spacing)
    assert all(j >= 0 for j in touched)
    assert all(j < N for j in touched)
    self.code += [('twist',start,stop,spacing,bm,be)+layout]
    self.moduli = tuple(r for p in moduli for r in p.twist(start,stop,spacing,bm,be))
def twists(self):
    """Schedule one twist or reduce per modulus.

    Each modulus must occupy consecutive positions M.pos.  Its exponent
    M.rootpow is moved into the symmetric range (subtract M.root when
    2*rootpow > root).  A zero exponent gives reduce(first,first+d,1);
    otherwise twist(first,first+d,1,d*root,exponent).  All reduces are
    issued before all twists.
    """
    reducetodo = []
    twisttodo = []
    for M in self.moduli:
        d = M.degree
        first = M.pos[0]
        assert list(M.pos) == list(range(first,first+d))
        bm, be = M.root, M.rootpow
        if be+be > bm:
            be -= bm
        if be == 0:
            reducetodo.append((first,first+d,1))
        else:
            twisttodo.append((first,first+d,1,d*bm,be))
    for args in reducetodo:
        self.reduce(*args)
    for args in twisttodo:
        self.twist(*args)
def butterfly(self,start,stop,spacing,offset,bm,be):
    """Record a layer of butterflies pairing j with j+offset.

    j runs over start, start+spacing, ... (before stop).  be is reduced mod
    bm.  Both j and j+offset must lie in 0..N-1.  Appends ('butterfly',
    start, stop, spacing, offset, bm, be, indexing, transformpos) and
    replaces self.moduli by the concatenation of every modulus's
    p.butterfly(...) tuple.
    """
    N, moduli = self.N, self.moduli
    layout = (tuple(self.indexing),tuple(self.transformpos))
    assert (stop-start)%spacing == 0
    be %= bm
    lows = range(start,stop,spacing)
    assert all(j >= 0 for j in lows)
    assert all(j < N for j in lows)
    assert all(j+offset >= 0 for j in lows)
    assert all(j+offset < N for j in lows)
    self.code += [('butterfly',start,stop,spacing,offset,bm,be)+layout]
    self.moduli = tuple(r for p in moduli for r in p.butterfly(start,stop,spacing,offset,bm,be))
def butterflies(self):
    """Schedule one butterfly layer per modulus.

    Each modulus of even degree d occupies consecutive positions M.pos.
    Its exponent is moved into the symmetric range, and the root order is
    doubled when the exponent is nonzero.  The call is then
    butterfly(first,first+d/2,1,d/2,root,exponent).  Butterflies are issued
    after all moduli have been examined.
    """
    pending = []
    for M in self.moduli:
        d = M.degree
        assert d%2 == 0
        first = M.pos[0]
        assert list(M.pos) == list(range(first,first+d))
        bm, be = M.root, M.rootpow
        if be+be > bm:
            be -= bm
        if be != 0:
            bm *= 2
        half = d//2
        pending.append((first,first+half,1,half,bm,be))
    for args in pending:
        self.butterfly(*args)
def fold(self,offset):
    """Fold the transform in half and move its top bit into the batch.

    Modulus i is combined with modulus i+len/2 via fold(other,offset).  The
    top indexing bit becomes a batch position; it is appended in place to
    the self.transformpos list, which is then sorted.  N is halved.  The
    change is bracketed by physical_unmap (old layout) and physical_map
    (new layout).
    """
    N = self.N
    indexing = self.indexing
    transformpos = self.transformpos
    moduli = self.moduli
    assert len(moduli)%2 == 0
    self.code += [('physical_unmap',tuple(indexing),tuple(transformpos))]
    half = len(moduli)//2
    folded = tuple(lo.fold(hi,offset) for lo,hi in zip(moduli[:half],moduli[half:]))
    transformpos.append(indexing[-1])
    indexing = indexing[:-1]
    transformpos.sort()
    self.N = N//2
    self.transformpos = transformpos
    self.indexing = indexing
    self.moduli = folded
    self.code += [('physical_map',tuple(indexing),tuple(transformpos))]
def nextbatch(self):
    """Close the current batch and immediately open another of the same size."""
    for step in (self.stopbatch,self.startbatch):
        step()
def halfbatch(self):
    """Halve the batch ("double reps").

    The highest physical bit position, len(indexing)+len(transformpos)-1,
    must be a batch bit.  It is removed in place from the self.transformpos
    list, so the batch size N << len(transformpos) halves.  The sequence
    recorded is: physical_unmap (old layout), close the old batch,
    'doublereps', open the new batch, physical_map (new layout).
    """
    indexing = self.indexing
    transformpos = self.transformpos
    self.code += [('physical_unmap',tuple(indexing),tuple(transformpos))]
    self.stopbatch()
    topbit = len(indexing)+len(transformpos)-1
    assert topbit in transformpos
    transformpos.remove(topbit)
    self.code += [('doublereps',)]
    self.transformpos = transformpos
    self.startbatch()
    self.code += [('physical_map',tuple(indexing),tuple(transformpos))]
# ----- SOURCE PROGRAM
def _parse_strategy(lines):
    """Parse the strategy program read by doit().

    Recognized lines (whitespace-separated fields):
      'N <int>'              sets the transform size,
      'q <int>'              appends a modulus to the q list,
      '<method> <int> ...'   appends a call state.<method>(ints...) to the
                             current layer.
    Blank lines end the current layer.  Runs of blank lines, leading blank
    lines and a trailing blank line never create empty layers.

    Returns (N, qlist, strategy).  N is None if no 'N' line appeared, and
    strategy is a list of layers, each a list of (methodname, args) pairs.
    """
    strategy = [[]]
    qlist = []
    N = None
    for line in lines:
        fields = line.split()
        if not fields:
            if strategy[-1]:
                strategy.append([])
            continue
        if fields[0] == 'N':
            N = int(fields[1])
            continue
        if fields[0] == 'q':
            qlist.append(int(fields[1]))
            continue
        strategy[-1].append((fields[0],tuple(map(int,fields[1:]))))
    if not strategy[-1]:
        strategy.pop()
    return N,qlist,strategy

def doit():
    """Compile the strategy read from standard input and write the result to standard output.

    Builds a state(N,qlist), runs each layer's method calls between
    startlayer/stoplayer (layers numbered from 1), then emits the code via
    stopntt, codegen and printused, and writes S.output to stdout.
    """
    N,qlist,strategy = _parse_strategy(sys.stdin)
    S = state(N,qlist)
    S.startntt()
    for layernum,layer in enumerate(strategy,start=1):
        S.startlayer(layernum)
        for functionname,args in layer:
            S.announce(functionname,args)
            # NOTE: any method name appearing in the input is invoked here,
            # so the strategy file must come from a trusted source.
            getattr(S,functionname)(*args)
        S.stoplayer(layernum)
    S.stopntt()
    S.codegen()
    S.printused()
    sys.stdout.write(S.output)

# Run only when executed as a script, not when the module is imported.
if __name__ == '__main__':
    doit()