
AES-NI speed-up (#459)

* converted the CBC loops further into SSE intrinsics

* added a 2nd parallel CBC decryption rail (AES-NI)

* transformed the while loop into a for loop

* offering 4 parallel AES-NI decryption rails

Co-authored-by: Logan007 <you@example.com>
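
For context: CBC chains every block into the next, which is why the two directions parallelize so differently. A minimal sketch of the recurrences, with E()/D() standing for one AES block encryption/decryption and C[-1] = IV:

    /*
     *   encrypt:  C[i] = E(P[i] ^ C[i-1])   -- serial: block i waits for ciphertext i-1
     *   decrypt:  P[i] = D(C[i]) ^ C[i-1]   -- parallel: all C[i] are known up front
     */

Hence the patch below keeps encryption as a single chained block per iteration but fans decryption out over four independent AES-NI pipelines.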
pull/461/head
Logan oos Even, 4 years ago, committed by GitHub (commit e3f64bfd1e)

src/aes.c: 198 lines changed
@@ -365,48 +365,194 @@ int aes_ecb_encrypt (unsigned char *out, const unsigned char *in, aes_context_t
 }
 
 #define fix_xor(target, source) *(uint32_t*)&(target)[0] = *(uint32_t*)&(target)[0] ^ *(uint32_t*)&(source)[0]; *(uint32_t*)&(target)[4] = *(uint32_t*)&(target)[4] ^ *(uint32_t*)&(source)[4]; \
                                 *(uint32_t*)&(target)[8] = *(uint32_t*)&(target)[8] ^ *(uint32_t*)&(source)[8]; *(uint32_t*)&(target)[12] = *(uint32_t*)&(target)[12] ^ *(uint32_t*)&(source)[12];
 
 int aes_cbc_encrypt (unsigned char *out, const unsigned char *in, size_t in_len,
                      const unsigned char *iv, aes_context_t *ctx) {
 
-    uint8_t tmp[AES_BLOCK_SIZE];
-    size_t i;
-    size_t n;
-
-    memcpy(tmp, iv, AES_BLOCK_SIZE);
-
-    n = in_len / AES_BLOCK_SIZE;
-    for(i=0; i < n; i++) {
-        fix_xor(tmp, &in[i * AES_BLOCK_SIZE]);
-        aes_internal_encrypt(ctx, tmp, tmp);
-        memcpy(&out[i * AES_BLOCK_SIZE], tmp, AES_BLOCK_SIZE);
-    }
-
-    return n * AES_BLOCK_SIZE;
+    int n;                      // number of blocks
+    int ret = (int)in_len & 15; // remainder
+
+    __m128i ivec = _mm_loadu_si128((__m128i*)iv);
+
+    for(n = in_len / 16; n != 0; n--) {
+        __m128i tmp = _mm_loadu_si128((__m128i*)in);
+        in += 16;
+        tmp = _mm_xor_si128(tmp, ivec);
+
+        tmp = _mm_xor_si128    (tmp, ctx->rk_enc[ 0]);
+        tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 1]);
+        tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 2]);
+        tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 3]);
+        tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 4]);
+        tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 5]);
+        tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 6]);
+        tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 7]);
+        tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 8]);
+        tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 9]);
+        if(ctx->Nr > 10) {
+            tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[10]);
+            tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[11]);
+            if(ctx->Nr > 12) {
+                tmp = _mm_aesenc_si128(tmp, ctx->rk_enc[12]);
+                tmp = _mm_aesenc_si128(tmp, ctx->rk_enc[13]);
+            }
+        }
+        tmp = _mm_aesenclast_si128(tmp, ctx->rk_enc[ctx->Nr]);
+        ivec = tmp;
+        _mm_storeu_si128((__m128i*)out, tmp);
+        out += 16;
+    }
+
+    return ret;
 }
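
The unrolled ladder above is the standard AES-NI round pattern: an initial whitening XOR with rk_enc[0], Nr-1 full rounds of AESENC, and AESENCLAST for the final round (which omits MixColumns). Nr is 10, 12 or 14 for AES-128/192/256, which is exactly what the if(ctx->Nr > 10) / if(ctx->Nr > 12) steps unroll. A minimal loop-form sketch, assuming only what the diff shows (rk_enc[] holding __m128i round keys, Nr the round count); the helper below is hypothetical and not part of the patch:

    #include <immintrin.h>

    // one AES block encryption via AES-NI, loop form of the unrolled ladder
    static inline __m128i aesni_encrypt_block (__m128i block, const aes_context_t *ctx) {

        int r;

        block = _mm_xor_si128(block, ctx->rk_enc[0]);              // round 0: whitening
        for(r = 1; r < ctx->Nr; r++)
            block = _mm_aesenc_si128(block, ctx->rk_enc[r]);       // rounds 1 .. Nr-1
        return _mm_aesenclast_si128(block, ctx->rk_enc[ctx->Nr]);  // final round, no MixColumns
    }

The commit unrolls this instead, presumably to keep loop control and branching out of the hot path.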
 int aes_cbc_decrypt (unsigned char *out, const unsigned char *in, size_t in_len,
                      const unsigned char *iv, aes_context_t *ctx) {
 
-    uint8_t tmp[AES_BLOCK_SIZE];
-    uint8_t old[AES_BLOCK_SIZE];
-    size_t i;
-    size_t n;
-
-    memcpy(tmp, iv, AES_BLOCK_SIZE);
-
-    n = in_len / AES_BLOCK_SIZE;
-    for(i=0; i < n; i++) {
-        memcpy(old, &in[i * AES_BLOCK_SIZE], AES_BLOCK_SIZE);
-        aes_internal_decrypt(ctx, &in[i * AES_BLOCK_SIZE], &out[i * AES_BLOCK_SIZE]);
-        fix_xor(&out[i * AES_BLOCK_SIZE], tmp);
-        memcpy(tmp, old, AES_BLOCK_SIZE);
-    }
-
-    return n * AES_BLOCK_SIZE;
+    int n;                      // number of blocks
+    int ret = (int)in_len & 15; // remainder
+
+    __m128i ivec = _mm_loadu_si128((__m128i*)iv);
+
+    for(n = in_len / 16; n > 3; n -= 4) {
+        __m128i tmp1 = _mm_loadu_si128((__m128i*)in); in += 16;
+        __m128i tmp2 = _mm_loadu_si128((__m128i*)in); in += 16;
+        __m128i tmp3 = _mm_loadu_si128((__m128i*)in); in += 16;
+        __m128i tmp4 = _mm_loadu_si128((__m128i*)in); in += 16;
+
+        __m128i old_in1 = tmp1;
+        __m128i old_in2 = tmp2;
+        __m128i old_in3 = tmp3;
+        __m128i old_in4 = tmp4;
+
+        tmp1 = _mm_xor_si128    (tmp1, ctx->rk_dec[ 0]); tmp2 = _mm_xor_si128    (tmp2, ctx->rk_dec[ 0]);
+        tmp3 = _mm_xor_si128    (tmp3, ctx->rk_dec[ 0]); tmp4 = _mm_xor_si128    (tmp4, ctx->rk_dec[ 0]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 1]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 1]);
+        tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 1]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 1]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 2]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 2]);
+        tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 2]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 2]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 3]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 3]);
+        tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 3]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 3]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 4]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 4]);
+        tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 4]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 4]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 5]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 5]);
+        tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 5]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 5]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 6]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 6]);
+        tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 6]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 6]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 7]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 7]);
+        tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 7]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 7]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 8]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 8]);
+        tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 8]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 8]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 9]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 9]);
+        tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 9]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 9]);
+        if(ctx->Nr > 10) {
+            tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[10]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[10]);
+            tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[10]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[10]);
+            tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[11]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[11]);
+            tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[11]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[11]);
+            if(ctx->Nr > 12) {
+                tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[12]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[12]);
+                tmp3 = _mm_aesdec_si128(tmp3, ctx->rk_dec[12]); tmp4 = _mm_aesdec_si128(tmp4, ctx->rk_dec[12]);
+                tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[13]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[13]);
+                tmp3 = _mm_aesdec_si128(tmp3, ctx->rk_dec[13]); tmp4 = _mm_aesdec_si128(tmp4, ctx->rk_dec[13]);
+            }
+        }
+        tmp1 = _mm_aesdeclast_si128(tmp1, ctx->rk_enc[ 0]); tmp2 = _mm_aesdeclast_si128(tmp2, ctx->rk_enc[ 0]);
+        tmp3 = _mm_aesdeclast_si128(tmp3, ctx->rk_enc[ 0]); tmp4 = _mm_aesdeclast_si128(tmp4, ctx->rk_enc[ 0]);
+
+        tmp1 = _mm_xor_si128 (tmp1, ivec);    tmp2 = _mm_xor_si128 (tmp2, old_in1);
+        tmp3 = _mm_xor_si128 (tmp3, old_in2); tmp4 = _mm_xor_si128 (tmp4, old_in3);
+        ivec = old_in4;
+
+        _mm_storeu_si128((__m128i*) out, tmp1); out += 16;
+        _mm_storeu_si128((__m128i*) out, tmp2); out += 16;
+        _mm_storeu_si128((__m128i*) out, tmp3); out += 16;
+        _mm_storeu_si128((__m128i*) out, tmp4); out += 16;
+    } // now: less than 4 blocks remaining
+
+    if(n > 1) { // 2 or 3 blocks remaining --> this code handles two of them
+        n -= 2;
+
+        __m128i tmp1 = _mm_loadu_si128((__m128i*)in); in += 16;
+        __m128i tmp2 = _mm_loadu_si128((__m128i*)in); in += 16;
+
+        __m128i old_in1 = tmp1;
+        __m128i old_in2 = tmp2;
+
+        tmp1 = _mm_xor_si128    (tmp1, ctx->rk_dec[ 0]); tmp2 = _mm_xor_si128    (tmp2, ctx->rk_dec[ 0]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 1]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 1]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 2]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 2]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 3]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 3]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 4]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 4]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 5]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 5]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 6]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 6]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 7]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 7]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 8]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 8]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 9]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 9]);
+        if(ctx->Nr > 10) {
+            tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[10]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[10]);
+            tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[11]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[11]);
+            if(ctx->Nr > 12) {
+                tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[12]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[12]);
+                tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[13]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[13]);
+            }
+        }
+        tmp1 = _mm_aesdeclast_si128(tmp1, ctx->rk_enc[ 0]); tmp2 = _mm_aesdeclast_si128(tmp2, ctx->rk_enc[ 0]);
+
+        tmp1 = _mm_xor_si128 (tmp1, ivec); tmp2 = _mm_xor_si128 (tmp2, old_in1);
+        ivec = old_in2;
+
+        _mm_storeu_si128((__m128i*) out, tmp1); out += 16;
+        _mm_storeu_si128((__m128i*) out, tmp2); out += 16;
+    }
+
+    if(n) { // one block remaining
+        __m128i tmp = _mm_loadu_si128((__m128i*)in);
+        __m128i old_in = tmp;
+
+        tmp = _mm_xor_si128    (tmp, ctx->rk_dec[ 0]);
+        tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 1]);
+        tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 2]);
+        tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 3]);
+        tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 4]);
+        tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 5]);
+        tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 6]);
+        tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 7]);
+        tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 8]);
+        tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 9]);
+        if(ctx->Nr > 10) {
+            tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[10]);
+            tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[11]);
+            if(ctx->Nr > 12) {
+                tmp = _mm_aesdec_si128(tmp, ctx->rk_dec[12]);
+                tmp = _mm_aesdec_si128(tmp, ctx->rk_dec[13]);
+            }
+        }
+        tmp = _mm_aesdeclast_si128(tmp, ctx->rk_enc[ 0]);
+        tmp = _mm_xor_si128 (tmp, ivec);
+        _mm_storeu_si128((__m128i*) out, tmp);
+    }
+
+    return ret;
 }
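
Two details in the decryption path are worth noting. The four interleaved chains exist to hide instruction latency: AESDEC is pipelined on typical AES-NI cores (a few cycles of latency at roughly one instruction per cycle of throughput), so four independent blocks keep the unit busy where a single chained block would stall; encryption cannot be reorganized this way because CBC serializes it. And, per AES's equivalent inverse cipher, the full rounds take the rk_dec[] schedule while AESDECLAST takes rk_enc[0], the original first round key. Callers should also note that both routines now return the unprocessed remainder in_len & 15 rather than the processed byte count. A short usage sketch under those semantics (hypothetical buffers; key setup not shown; the translation unit must be built with AES-NI enabled, e.g. gcc -maes):

    aes_context_t ctx;                                 // filled by the module's key setup (not shown)
    unsigned char iv[AES_BLOCK_SIZE]         = { 0 };  // per-message IV; all-zero only for illustration
    unsigned char cipher[4 * AES_BLOCK_SIZE] = { 0 };  // four blocks: exactly one pass of the 4-rail loop
    unsigned char plain[4 * AES_BLOCK_SIZE];

    int rest = aes_cbc_decrypt(plain, cipher, sizeof(cipher), iv, &ctx);
    // rest == 0 here: sizeof(cipher) is a multiple of 16, so no tail bytes were left over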
