-rwxr-xr-x 808 nttcompiler-20220411/scripts/spec-interp
#!/usr/bin/env python3
import sys
def evaluation_matrix(points, q):
    """Return the Vandermonde matrix ``E[i][j] = points[i]**j mod q``.

    Row i evaluates a polynomial with len(points) coefficients at points[i];
    the matrix is square (len(points) x len(points)).
    """
    n = len(points)
    return [[pow(point, j, q) for j in range(n)] for point in points]


def scaled_inverse(evals):
    """Return the candidate scaled inverse ``interp[j][k] = evals[k][(n-j) % n]``.

    Because evals[k][e] == points[k]**e, entry (j, k) is
    points[k]**((n-j) % n).  When the points are distinct n-th roots of unity
    modulo a prime q, that equals points[k]**(-j), and E * interp == n * I
    (mod q), i.e. interp is n times the inverse of E.  For other points the
    product is generally not a scaled identity; check_scaled_inverse()
    detects that case.
    """
    n = len(evals)
    return [[evals[k][(n - j) % n] for k in range(n)] for j in range(n)]


def check_scaled_inverse(evals, interp, q):
    """Verify that ``evals * interp == n * I (mod q)``, where n = len(evals).

    Raises:
        ValueError: naming the first (row, column) entry of the product that
            differs from n * I modulo q.
    """
    n = len(evals)
    for i in range(n):
        for k in range(n):
            s = sum(evals[i][j] * interp[j][k] for j in range(n))
            expected = n if i == k else 0
            if (s - expected) % q != 0:
                raise ValueError(
                    'not a scaled inverse: entry (%d,%d) is %d mod %d, expected %d'
                    % (i, k, s % q, q, expected % q))


def format_outputs(interp, q):
    """Return one ``remout_f_<out> = c0*in_f_0 + c1*in_f_1 + ...`` line per row."""
    return [
        'remout_f_%d = ' % out
        + ' + '.join('%d*in_f_%d' % (coeff % q, i) for i, coeff in enumerate(row))
        for out, row in enumerate(interp)
    ]


def main(argv=None):
    """Command-line entry point: ``spec-interp INPUTS Q POINT...``.

    Prints one interpolation formula per output to stdout.  If the arguments
    are inconsistent, or the points do not yield a scaled inverse, it exits
    with status 1 and writes a message to stderr.

    Args:
        argv: argument list without the program name; defaults to
            sys.argv[1:].
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) < 2:
        sys.exit('usage: spec-interp INPUTS Q POINT...')
    inputs = int(args[0])
    q = int(args[1])
    points = [int(point) for point in args[2:]]
    # Explicit checks rather than `assert`: asserts are stripped under
    # `python -O`, and the script would then silently emit a wrong matrix.
    if len(points) != inputs:
        sys.exit('spec-interp: expected %d points, got %d' % (inputs, len(points)))
    evals = evaluation_matrix(points, q)
    interp = scaled_inverse(evals)
    try:
        check_scaled_inverse(evals, interp, q)
    except ValueError as err:
        sys.exit('spec-interp: %s' % err)
    for line in format_outputs(interp, q):
        print(line)


if __name__ == '__main__':
    main()