-rw-r--r-- 3348 nttcompiler-20220411/512/works.c
#include <stdlib.h> #include <assert.h> #include "ntt_512.h" typedef int16_t int16; #define ALIGN __attribute((aligned(32))) ALIGN int16 base[512]; ALIGN int16 M[512*512]; ALIGN int16 f[512*512]; ALIGN int16 g[512*512]; #define REPS 30 #define OFFSETS 128 ALIGN int16 h[REPS*512+OFFSETS]; long long qlist[2] = {7681,10753}; void (*nttlist[2])(int16*,long long) = {ntt_512_7681,ntt_512_10753}; void (*invnttlist[2])(int16*,long long) = {ntt_512_7681_inv,ntt_512_10753_inv}; int main() { for (long long qpos = 0;qpos < 2;++qpos) { long long q = qlist[qpos]; void (*ntt)(int16*,long long) = nttlist[qpos]; void (*invntt)(int16*,long long) = invnttlist[qpos]; // test that basis gives powers char seenbase[q]; for (long long i = 0;i < q;++i) seenbase[i] = 0; for (long long e = 0;e < 512;++e) { for (long long j = 0;j < 512;++j) M[e*512+j] = 0; M[e*512+e] = 1; } ntt(M,512); for (long long j = 0;j < 512;++j) { long long z = M[1*512+j]; z %= q; if (z < 0) z += q; assert(z >= 0); assert(z < q); seenbase[z] = 1; long long ze = 1; for (long long e = 0;e < 512;++e) { assert((ze-M[e*512+j])%q == 0); ze *= z; ze %= q; if (ze < 0) ze += q; } } // test that powers are of 512th roots of 1 char root[10][q]; for (long long r = 0;r < 10;++r) { if (r == 0) for (long long i = 0;i < q;++i) root[r][i] = i == 1; else for (long long i = 0;i < q;++i) { long long ii = i*i % q; assert(ii >= 0); assert(ii < q); assert(r-1 >= 0); root[r][i] = root[r-1][ii]; } } for (long long i = 0;i < q;++i) assert(root[9][i] == seenbase[i]); // test that some random examples pass bounds checks and linearity checks // XXX: rethink how input bounds should be selected here for (long long j = 0;j < 512*512;++j) { M[j] %= q; M[j] += q; M[j] %= q; if (random()&1) M[j] -= q; if (M[j] > 8000) M[j] -= q; if (M[j] < -8000) M[j] += q; } for (long long loop = 0;loop < 30;++loop) { for (long long j = 0;j < 512*512;++j) g[j] = f[j] = (random()%q)-(q/2); ntt(g,512); if (loop == 0) for (long long e = 0;e < 512;++e) for (long long j = 0;j < 512;++j) { long long s = 0; for (long long i = 0;i < 512;++i) s += f[e*512+i]*(long long) M[i*512+j]; assert((s-g[e*512+j])%q == 0); } } // test that inverse gives identity invntt(M,512); for (long long e = 0;e < 512;++e) { for (long long j = 0;j < 512;++j) if (j != e) assert(M[e*512+j]%q == 0); assert(((M[e*512+e]%q)+q)%q == 512); } // test for consistency across reps and alignment for (long long reps = 1;reps < REPS;++reps) { if (reps > 512) continue; for (long long offset = 0;offset < OFFSETS;++offset) { for (long long j = 0;j < 512*reps;++j) g[j] = h[offset+j] = (random()%q)-(q/2); for (long long j = 0;j < 512*reps;j += 512) ntt(g+j,1); ntt(h+offset,reps); for (long long j = 0;j < 512*reps;++j) assert(g[j] == h[offset+j]); } } } return 0; }