From e3f64bfd1e36b50b11d0357f0f54ef0384284a8a Mon Sep 17 00:00:00 2001
From: Logan oos Even <46396513+Logan007@users.noreply.github.com>
Date: Sun, 11 Oct 2020 15:28:04 +0545
Subject: [PATCH] AES-NI speed-up (#459)

* converted cbc loops more into sse

* added a 2nd parallel cbc decryption (aes-ni)

* transformed while into for

* offering 4 parallel aes-ni rails

Co-authored-by: Logan007
---
 src/aes.c | 198 +++++++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 172 insertions(+), 26 deletions(-)

diff --git a/src/aes.c b/src/aes.c
index 32dc577..4698e36 100644
--- a/src/aes.c
+++ b/src/aes.c
@@ -365,48 +365,194 @@ int aes_ecb_encrypt (unsigned char *out, const unsigned char *in, aes_context_t
 }
 
 
-#define fix_xor(target, source) *(uint32_t*)&(target)[0] = *(uint32_t*)&(target)[0] ^ *(uint32_t*)&(source)[0]; *(uint32_t*)&(target)[4] = *(uint32_t*)&(target)[4] ^ *(uint32_t*)&(source)[4]; \
-                                *(uint32_t*)&(target)[8] = *(uint32_t*)&(target)[8] ^ *(uint32_t*)&(source)[8]; *(uint32_t*)&(target)[12] = *(uint32_t*)&(target)[12] ^ *(uint32_t*)&(source)[12];
-
-
 int aes_cbc_encrypt (unsigned char *out, const unsigned char *in, size_t in_len, const unsigned char *iv, aes_context_t *ctx) {
 
-    uint8_t tmp[AES_BLOCK_SIZE];
-    size_t i;
-    size_t n;
+    int n;                       // number of blocks
+    int ret = (int)in_len & 15;  // remainder
+
+    __m128i ivec = _mm_loadu_si128((__m128i*)iv);
+
+    for(n = in_len / 16; n != 0; n--) {
+        __m128i tmp = _mm_loadu_si128((__m128i*)in);
+        in += 16;
+        tmp = _mm_xor_si128(tmp, ivec);
+
+        tmp = _mm_xor_si128     (tmp, ctx->rk_enc[ 0]);
+        tmp = _mm_aesenc_si128  (tmp, ctx->rk_enc[ 1]);
+        tmp = _mm_aesenc_si128  (tmp, ctx->rk_enc[ 2]);
+        tmp = _mm_aesenc_si128  (tmp, ctx->rk_enc[ 3]);
+        tmp = _mm_aesenc_si128  (tmp, ctx->rk_enc[ 4]);
+        tmp = _mm_aesenc_si128  (tmp, ctx->rk_enc[ 5]);
+        tmp = _mm_aesenc_si128  (tmp, ctx->rk_enc[ 6]);
+        tmp = _mm_aesenc_si128  (tmp, ctx->rk_enc[ 7]);
+        tmp = _mm_aesenc_si128  (tmp, ctx->rk_enc[ 8]);
+        tmp = _mm_aesenc_si128  (tmp, ctx->rk_enc[ 9]);
+        if(ctx->Nr > 10) {
+            tmp = _mm_aesenc_si128  (tmp, ctx->rk_enc[10]);
+            tmp = _mm_aesenc_si128  (tmp, ctx->rk_enc[11]);
+            if(ctx->Nr > 12) {
+                tmp = _mm_aesenc_si128(tmp, ctx->rk_enc[12]);
+                tmp = _mm_aesenc_si128(tmp, ctx->rk_enc[13]);
+            }
+        }
+        tmp = _mm_aesenclast_si128(tmp, ctx->rk_enc[ctx->Nr]);
 
-    memcpy(tmp, iv, AES_BLOCK_SIZE);
+        ivec = tmp;
 
-    n = in_len / AES_BLOCK_SIZE;
-    for(i=0; i < n; i++) {
-        fix_xor(tmp, &in[i * AES_BLOCK_SIZE]);
-        aes_internal_encrypt(ctx, tmp, tmp);
-        memcpy(&out[i * AES_BLOCK_SIZE], tmp, AES_BLOCK_SIZE);
+        _mm_storeu_si128((__m128i*)out, tmp);
+        out += 16;
     }
-    return n * AES_BLOCK_SIZE;
+
+    return ret;
 }
 
 
 int aes_cbc_decrypt (unsigned char *out, const unsigned char *in, size_t in_len, const unsigned char *iv, aes_context_t *ctx) {
 
-    uint8_t tmp[AES_BLOCK_SIZE];
-    uint8_t old[AES_BLOCK_SIZE];
-    size_t i;
-    size_t n;
+    int n;                       // number of blocks
+    int ret = (int)in_len & 15;  // remainder
 
-    memcpy(tmp, iv, AES_BLOCK_SIZE);
+    __m128i ivec = _mm_loadu_si128((__m128i*)iv);
 
-    n = in_len / AES_BLOCK_SIZE;
-    for(i=0; i < n; i++) {
-        memcpy(old, &in[i * AES_BLOCK_SIZE], AES_BLOCK_SIZE);
-        aes_internal_decrypt(ctx, &in[i * AES_BLOCK_SIZE], &out[i * AES_BLOCK_SIZE]);
-        fix_xor(&out[i * AES_BLOCK_SIZE], tmp);
-        memcpy(tmp, old, AES_BLOCK_SIZE);
+    for(n = in_len / 16; n > 3; n -=4) {
+        __m128i tmp1 = _mm_loadu_si128((__m128i*)in); in += 16;
+        __m128i tmp2 = _mm_loadu_si128((__m128i*)in); in += 16;
+        __m128i tmp3 = _mm_loadu_si128((__m128i*)in); in += 16;
+        __m128i tmp4 = _mm_loadu_si128((__m128i*)in); in += 16;
+
+        __m128i old_in1 = tmp1;
+        __m128i old_in2 = tmp2;
+        __m128i old_in3 = tmp3;
+        __m128i old_in4 = tmp4;
+
+        tmp1 = _mm_xor_si128    (tmp1, ctx->rk_dec[ 0]);    tmp2 = _mm_xor_si128    (tmp2, ctx->rk_dec[ 0]);
+        tmp3 = _mm_xor_si128    (tmp3, ctx->rk_dec[ 0]);    tmp4 = _mm_xor_si128    (tmp4, ctx->rk_dec[ 0]);
+
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 1]);    tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 1]);
+        tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 1]);    tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 1]);
+
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 2]);    tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 2]);
+        tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 2]);    tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 2]);
+
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 3]);    tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 3]);
+        tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 3]);    tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 3]);
+
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 4]);    tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 4]);
+        tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 4]);    tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 4]);
+
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 5]);    tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 5]);
+        tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 5]);    tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 5]);
+
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 6]);    tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 6]);
+        tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 6]);    tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 6]);
+
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 7]);    tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 7]);
+        tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 7]);    tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 7]);
+
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 8]);    tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 8]);
+        tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 8]);    tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 8]);
+
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 9]);    tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 9]);
+        tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 9]);    tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 9]);
+
+        if(ctx->Nr > 10) {
+            tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[10]);    tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[10]);
+            tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[10]);    tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[10]);
+
+            tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[11]);    tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[11]);
+            tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[11]);    tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[11]);
+
+            if(ctx->Nr > 12) {
+                tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[12]);    tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[12]);
+                tmp3 = _mm_aesdec_si128(tmp3, ctx->rk_dec[12]);    tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[12]);
+
+                tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[13]);    tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[13]);
+                tmp3 = _mm_aesdec_si128(tmp3, ctx->rk_dec[13]);    tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[13]);
+            }
+        }
+        tmp1 = _mm_aesdeclast_si128(tmp1, ctx->rk_enc[ 0]);    tmp2 = _mm_aesdeclast_si128(tmp2, ctx->rk_enc[ 0]);
+        tmp3 = _mm_aesdeclast_si128(tmp3, ctx->rk_enc[ 0]);    tmp4 = _mm_aesdeclast_si128(tmp4, ctx->rk_enc[ 0]);
+
+        tmp1 = _mm_xor_si128 (tmp1, ivec);       tmp2 = _mm_xor_si128 (tmp2, old_in1);
+        tmp3 = _mm_xor_si128 (tmp3, old_in2);    tmp4 = _mm_xor_si128 (tmp4, old_in3);
+
+        ivec = old_in4;
+
+        _mm_storeu_si128((__m128i*) out, tmp1); out += 16;
+        _mm_storeu_si128((__m128i*) out, tmp2); out += 16;
+        _mm_storeu_si128((__m128i*) out, tmp3); out += 16;
+        _mm_storeu_si128((__m128i*) out, tmp4); out += 16;
+    } // now: less than 4 blocks remaining
+
+    if(n > 1) { // 2 or 3 blocks remaining --> this code handles two of them
+        n-= 2;
+
+        __m128i tmp1 = _mm_loadu_si128((__m128i*)in); in += 16;
+        __m128i tmp2 = _mm_loadu_si128((__m128i*)in); in += 16;
+
+        __m128i old_in1 = tmp1;
+        __m128i old_in2 = tmp2;
+
+        tmp1 = _mm_xor_si128    (tmp1, ctx->rk_dec[ 0]);    tmp2 = _mm_xor_si128    (tmp2, ctx->rk_dec[ 0]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 1]);    tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 1]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 2]);    tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 2]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 3]);    tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 3]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 4]);    tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 4]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 5]);    tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 5]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 6]);    tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 6]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 7]);    tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 7]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 8]);    tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 8]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 9]);    tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 9]);
+        if(ctx->Nr > 10) {
+            tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[10]);    tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[10]);
+            tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[11]);    tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[11]);
+            if(ctx->Nr > 12) {
+                tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[12]);    tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[12]);
+                tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[13]);    tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[13]);
+            }
+        }
+        tmp1 = _mm_aesdeclast_si128(tmp1, ctx->rk_enc[ 0]);    tmp2 = _mm_aesdeclast_si128(tmp2, ctx->rk_enc[ 0]);
+
+        tmp1 = _mm_xor_si128 (tmp1, ivec);    tmp2 = _mm_xor_si128 (tmp2, old_in1);
+
+        ivec = old_in2;
+
+        _mm_storeu_si128((__m128i*) out, tmp1); out += 16;
+        _mm_storeu_si128((__m128i*) out, tmp2); out += 16;
     }
 
-    return n * AES_BLOCK_SIZE;
+    if(n) { // one block remaining
+        __m128i tmp = _mm_loadu_si128((__m128i*)in);
+        __m128i old_in = tmp;
+
+        tmp = _mm_xor_si128    (tmp, ctx->rk_dec[ 0]);
+        tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 1]);
+        tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 2]);
+        tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 3]);
+        tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 4]);
+        tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 5]);
+        tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 6]);
+        tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 7]);
+        tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 8]);
+        tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 9]);
+        if(ctx->Nr > 10) {
+            tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[10]);
+            tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[11]);
+            if(ctx->Nr > 12) {
+                tmp = _mm_aesdec_si128(tmp, ctx->rk_dec[12]);
+                tmp = _mm_aesdec_si128(tmp, ctx->rk_dec[13]);
+            }
+        }
+        tmp = _mm_aesdeclast_si128(tmp, ctx->rk_enc[ 0]);
+
+        tmp = _mm_xor_si128 (tmp, ivec);
+
+        _mm_storeu_si128((__m128i*) out, tmp);
+    }
+
+    return ret;
 }
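
Background note, not part of the patch itself: the decryption rails can run in parallel because in CBC mode the block-cipher part of decryption, D_k(C[i]), does not depend on any other block; only the final XOR consumes the previous ciphertext block (or the IV for the first block). CBC encryption, by contrast, feeds each ciphertext block into the next one, which is why aes_cbc_encrypt above remains a single serial chain. A minimal scalar sketch of that dependency structure follows; the decrypt_block callback and cbc_decrypt_reference name are hypothetical stand-ins for illustration, not identifiers from src/aes.c:

#include <stddef.h>
#include <stdint.h>

/* hypothetical stand-in for one AES block decryption (the AES-NI round sequence in the patch) */
typedef void (*block_decrypt_fn)(uint8_t out[16], const uint8_t in[16], const void *key);

static void cbc_decrypt_reference (uint8_t *out, const uint8_t *in, size_t n_blocks,
                                   const uint8_t iv[16], block_decrypt_fn decrypt_block,
                                   const void *key) {

    const uint8_t *prev = iv;                            // C[i-1], starts at the IV

    for(size_t i = 0; i < n_blocks; i++) {
        uint8_t d[16];
        decrypt_block(d, &in[16 * i], key);              // independent per block --> can be interleaved across rails
        for(int b = 0; b < 16; b++)
            out[16 * i + b] = (uint8_t)(d[b] ^ prev[b]); // only this step looks at the previous block
        prev = &in[16 * i];                              // chain to this block's ciphertext, not its plaintext
    }
}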