diff --git a/include/aes.h b/include/aes.h
index f1c4e57..7a1e37d 100644
--- a/include/aes.h
+++ b/include/aes.h
@@ -37,39 +37,39 @@ #define AES128_KEY_BYTES (128/8)
 
-#if defined (HAVE_OPENSSL_1_1) // openSSL 1.1 ---------------------------------------------
+#if defined (HAVE_OPENSSL_1_1) // openSSL 1.1 ---------------------------------------------------------------------
 
 #include <openssl/aes.h>
 #include <openssl/evp.h>
 #include <openssl/err.h>
 
 typedef struct aes_context_t {
-  EVP_CIPHER_CTX *enc_ctx;            /* openssl's reusable evp_* en/de-cryption context */
-  EVP_CIPHER_CTX *dec_ctx;            /* openssl's reusable evp_* en/de-cryption context */
-  const EVP_CIPHER *cipher;           /* cipher to use: e.g. EVP_aes_128_cbc */
-  uint8_t key[AES256_KEY_BYTES];      /* the pure key data for payload encryption & decryption */
-  AES_KEY ecb_dec_key;                /* one step ecb decryption key */
+    EVP_CIPHER_CTX *enc_ctx;          /* openssl's reusable evp_* en/de-cryption context */
+    EVP_CIPHER_CTX *dec_ctx;          /* openssl's reusable evp_* en/de-cryption context */
+    const EVP_CIPHER *cipher;         /* cipher to use: e.g. EVP_aes_128_cbc */
+    uint8_t key[AES256_KEY_BYTES];    /* the pure key data for payload encryption & decryption */
+    AES_KEY ecb_dec_key;              /* one step ecb decryption key */
 } aes_context_t;
 
-#elif defined (__AES__) && defined (__SSE2__) // Intel's AES-NI ---------------------------
+#elif defined (__AES__) && defined (__SSE2__) // Intel's AES-NI ---------------------------------------------------
 
 #include <immintrin.h>
 
 typedef struct aes_context_t {
-  __m128i rk_enc[15];
-  __m128i rk_dec[15];
-  int Nr;
+    __m128i rk_enc[15];
+    __m128i rk_dec[15];
+    int     Nr;
 } aes_context_t;
 
-#else // plain C --------------------------------------------------------------------------
+#else // plain C ---------------------------------------------------------------------------------------------------
 
 typedef struct aes_context_t {
-  uint32_t enc_rk[60]; // round keys for encryption
-  uint32_t dec_rk[60]; // round keys for decryption
-  int Nr;              // number of rounds
+    uint32_t enc_rk[60];    // round keys for encryption
+    uint32_t dec_rk[60];    // round keys for decryption
+    int      Nr;            // number of rounds
 } aes_context_t;
 
-#endif // ---------------------------------------------------------------------------------
+#endif // ----------------------------------------------------------------------------------------------------------
 
 
 int aes_cbc_encrypt (unsigned char *out, const unsigned char *in, size_t in_len,
diff --git a/src/aes.c b/src/aes.c
index 864a1b1..510d674 100644
--- a/src/aes.c
+++ b/src/aes.c
@@ -20,148 +20,151 @@ #include "n2n.h"
 
-#if defined (HAVE_OPENSSL_1_1) // openSSL 1.1 ---------------------------------------------
+#if defined (HAVE_OPENSSL_1_1) // openSSL 1.1 ---------------------------------------------------------------------
 
 // get any error message out of openssl
 // taken from https://en.wikibooks.org/wiki/OpenSSL/Error_handling
 static char *openssl_err_as_string (void) {
 
-  BIO *bio = BIO_new (BIO_s_mem ());
-  ERR_print_errors (bio);
-  char *buf = NULL;
-  size_t len = BIO_get_mem_data (bio, &buf);
-  char *ret = (char *) calloc (1, 1 + len);
+    BIO *bio = BIO_new (BIO_s_mem ());
+    ERR_print_errors (bio);
+    char *buf = NULL;
+    size_t len = BIO_get_mem_data (bio, &buf);
+    char *ret = (char *) calloc (1, 1 + len);
 
-  if(ret)
-    memcpy (ret, buf, len);
+    if(ret)
+        memcpy (ret, buf, len);
 
-  BIO_free (bio);
-  return ret;
+    BIO_free (bio);
+
+    return ret;
 }
 
 
 int aes_cbc_encrypt (unsigned char *out, const unsigned char *in, size_t in_len,
                      const unsigned char *iv, aes_context_t *ctx) {
 
-  int evp_len;
-  int evp_ciphertext_len;
-
-  if(1 == EVP_EncryptInit_ex(ctx->enc_ctx, ctx->cipher, NULL, ctx->key, iv)) {
-    if(1 == EVP_CIPHER_CTX_set_padding(ctx->enc_ctx, 0)) {
-      if(1 == EVP_EncryptUpdate(ctx->enc_ctx, out, &evp_len, in, in_len)) {
-        evp_ciphertext_len = evp_len;
-        if(1 == EVP_EncryptFinal_ex(ctx->enc_ctx, out + evp_len, &evp_len)) {
-          evp_ciphertext_len += evp_len;
-          if(evp_ciphertext_len != in_len)
-            traceEvent(TRACE_ERROR, "aes_cbc_encrypt openssl encryption: encrypted %u bytes where %u were expected",
-                       evp_ciphertext_len, in_len);
+    int evp_len;
+    int evp_ciphertext_len;
+
+    if(1 == EVP_EncryptInit_ex(ctx->enc_ctx, ctx->cipher, NULL, ctx->key, iv)) {
+        if(1 == EVP_CIPHER_CTX_set_padding(ctx->enc_ctx, 0)) {
+            if(1 == EVP_EncryptUpdate(ctx->enc_ctx, out, &evp_len, in, in_len)) {
+                evp_ciphertext_len = evp_len;
+                if(1 == EVP_EncryptFinal_ex(ctx->enc_ctx, out + evp_len, &evp_len)) {
+                    evp_ciphertext_len += evp_len;
+                    if(evp_ciphertext_len != in_len)
+                        traceEvent(TRACE_ERROR, "aes_cbc_encrypt openssl encryption: encrypted %u bytes where %u were expected",
+                                   evp_ciphertext_len, in_len);
+                } else
+                    traceEvent(TRACE_ERROR, "aes_cbc_encrypt openssl final encryption: %s",
+                               openssl_err_as_string());
+            } else
+                traceEvent(TRACE_ERROR, "aes_cbc_encrypt openssl encryption: %s",
+                           openssl_err_as_string());
         } else
-          traceEvent(TRACE_ERROR, "aes_cbc_encrypt openssl final encryption: %s",
-                     openssl_err_as_string());
-      } else
-        traceEvent(TRACE_ERROR, "aes_cbc_encrypt openssl encrpytion: %s",
-                   openssl_err_as_string());
+            traceEvent(TRACE_ERROR, "aes_cbc_encrypt openssl padding setup: %s",
+                       openssl_err_as_string());
     } else
-      traceEvent(TRACE_ERROR, "aes_cbc_encrypt openssl padding setup: %s",
-                 openssl_err_as_string());
-  } else
-    traceEvent(TRACE_ERROR, "aes_cbc_encrypt openssl init: %s",
-               openssl_err_as_string());
+        traceEvent(TRACE_ERROR, "aes_cbc_encrypt openssl init: %s",
+                   openssl_err_as_string());
 
-  EVP_CIPHER_CTX_reset(ctx->enc_ctx);
+    EVP_CIPHER_CTX_reset(ctx->enc_ctx);
 
-  return 0;
+    return 0;
 }
 
 
 int aes_cbc_decrypt (unsigned char *out, const unsigned char *in, size_t in_len,
                      const unsigned char *iv, aes_context_t *ctx) {
 
-  int evp_len;
-  int evp_plaintext_len;
-
-  if(1 == EVP_DecryptInit_ex(ctx->dec_ctx, ctx->cipher, NULL, ctx->key, iv)) {
-    if(1 == EVP_CIPHER_CTX_set_padding(ctx->dec_ctx, 0)) {
-      if(1 == EVP_DecryptUpdate(ctx->dec_ctx, out, &evp_len, in, in_len)) {
-        evp_plaintext_len = evp_len;
-        if(1 == EVP_DecryptFinal_ex(ctx->dec_ctx, out + evp_len, &evp_len)) {
-          evp_plaintext_len += evp_len;
-          if(evp_plaintext_len != in_len)
-            traceEvent(TRACE_ERROR, "aes_cbc_decrypt openssl decryption: decrypted %u bytes where %u were expected",
-                       evp_plaintext_len, in_len);
+    int evp_len;
+    int evp_plaintext_len;
+
+    if(1 == EVP_DecryptInit_ex(ctx->dec_ctx, ctx->cipher, NULL, ctx->key, iv)) {
+        if(1 == EVP_CIPHER_CTX_set_padding(ctx->dec_ctx, 0)) {
+            if(1 == EVP_DecryptUpdate(ctx->dec_ctx, out, &evp_len, in, in_len)) {
+                evp_plaintext_len = evp_len;
+                if(1 == EVP_DecryptFinal_ex(ctx->dec_ctx, out + evp_len, &evp_len)) {
+                    evp_plaintext_len += evp_len;
+                    if(evp_plaintext_len != in_len)
+                        traceEvent(TRACE_ERROR, "aes_cbc_decrypt openssl decryption: decrypted %u bytes where %u were expected",
+                                   evp_plaintext_len, in_len);
+                } else
+                    traceEvent(TRACE_ERROR, "aes_cbc_decrypt openssl final decryption: %s",
+                               openssl_err_as_string());
+            } else
+                traceEvent(TRACE_ERROR, "aes_cbc_decrypt openssl decryption: %s",
+                           openssl_err_as_string());
        } else
-          traceEvent(TRACE_ERROR, "aes_cbc_decrypt openssl final decryption: %s",
-                     openssl_err_as_string());
-      } else
-        traceEvent(TRACE_ERROR, "aes_cbc_decrypt openssl decrpytion: %s",
-                   openssl_err_as_string());
+            traceEvent(TRACE_ERROR, "aes_cbc_decrypt openssl padding setup: %s",
+                       openssl_err_as_string());
    } else
-      traceEvent(TRACE_ERROR, "aes_cbc_decrypt openssl padding setup: %s",
-                 openssl_err_as_string());
-  } else
-    traceEvent(TRACE_ERROR, "aes_cbc_decrypt openssl init: %s",
-               openssl_err_as_string());
+        traceEvent(TRACE_ERROR, "aes_cbc_decrypt openssl init: %s",
+                   openssl_err_as_string());
 
-  EVP_CIPHER_CTX_reset(ctx->dec_ctx);
+    EVP_CIPHER_CTX_reset(ctx->dec_ctx);
 
-  return 0;
+    return 0;
 }
 
 
 int aes_ecb_decrypt (unsigned char *out, const unsigned char *in, aes_context_t *ctx) {
 
-  AES_ecb_encrypt(in, out, &(ctx->ecb_dec_key), AES_DECRYPT);
+    AES_ecb_encrypt(in, out, &(ctx->ecb_dec_key), AES_DECRYPT);
 
-  return 0;
+    return 0;
 }
 
 
 int aes_init (const unsigned char *key, size_t key_size, aes_context_t **ctx) {
 
-  // allocate context...
-  *ctx = (aes_context_t*) calloc(1, sizeof(aes_context_t));
-  if (!(*ctx))
-    return -1;
-  // ...and fill her up
-
-  // initialize data structures
-  if(!((*ctx)->enc_ctx = EVP_CIPHER_CTX_new())) {
-    traceEvent(TRACE_ERROR, "aes_init openssl's evp_* encryption context creation failed: %s",
-               openssl_err_as_string());
-    return -1;
-  }
-  if(!((*ctx)->dec_ctx = EVP_CIPHER_CTX_new())) {
-    traceEvent(TRACE_ERROR, "aes_init openssl's evp_* decryption context creation failed: %s",
-               openssl_err_as_string());
-    return -1;
-  }
-
-  // check key size and make key size (given in bytes) dependant settings
-  switch(key_size) {
-    case AES128_KEY_BYTES: // 128 bit key size
-      (*ctx)->cipher = EVP_aes_128_cbc();
-      break;
-    case AES192_KEY_BYTES: // 192 bit key size
-      (*ctx)->cipher = EVP_aes_192_cbc();
-      break;
-    case AES256_KEY_BYTES: // 256 bit key size
-      (*ctx)->cipher = EVP_aes_256_cbc();
-      break;
-    default:
-      traceEvent(TRACE_ERROR, "aes_init invalid key size %u\n", key_size);
-      return -1;
-  }
-
-  // key materiel handling
-  memcpy((*ctx)->key, key, key_size);
-  AES_set_decrypt_key(key, key_size * 8, &((*ctx)->ecb_dec_key));
-
-  return 0;
+    // allocate context...
+    *ctx = (aes_context_t*) calloc(1, sizeof(aes_context_t));
+    if(!(*ctx))
+        return -1;
+
+    // ...and fill her up:
+
+    // initialize data structures
+    if(!((*ctx)->enc_ctx = EVP_CIPHER_CTX_new())) {
+        traceEvent(TRACE_ERROR, "aes_init openssl's evp_* encryption context creation failed: %s",
+                   openssl_err_as_string());
+        return -1;
+    }
+
+    if(!((*ctx)->dec_ctx = EVP_CIPHER_CTX_new())) {
+        traceEvent(TRACE_ERROR, "aes_init openssl's evp_* decryption context creation failed: %s",
+                   openssl_err_as_string());
+        return -1;
+    }
+
+    // check key size and make key size (given in bytes) dependent settings
+    switch(key_size) {
+        case AES128_KEY_BYTES: // 128 bit key size
+            (*ctx)->cipher = EVP_aes_128_cbc();
+            break;
+        case AES192_KEY_BYTES: // 192 bit key size
+            (*ctx)->cipher = EVP_aes_192_cbc();
+            break;
+        case AES256_KEY_BYTES: // 256 bit key size
+            (*ctx)->cipher = EVP_aes_256_cbc();
+            break;
+        default:
+            traceEvent(TRACE_ERROR, "aes_init invalid key size %u\n", key_size);
+            return -1;
+    }
+
+    // key material handling
+    memcpy((*ctx)->key, key, key_size);
+    AES_set_decrypt_key(key, key_size * 8, &((*ctx)->ecb_dec_key));
+
+    return 0;
 }
 
 
-#elif defined (__AES__) && defined (__SSE2__) // Intel's AES-NI ---------------------------
+#elif defined (__AES__) && defined (__SSE2__) // Intel's AES-NI ---------------------------------------------------
 
 // inspired by https://gist.github.com/acapola/d5b940da024080dfaf5f
 
@@ -172,176 +175,190 @@ int aes_init (const unsigned char *key, size_t key_size, aes_context_t **ctx) {
 
 static __m128i aes128_keyexpand(__m128i key, __m128i keygened, uint8_t shuf) {
-  key = _mm_xor_si128(key, _mm_slli_si128(key, 4));
-  key = _mm_xor_si128(key, _mm_slli_si128(key, 4));
-  key = _mm_xor_si128(key, _mm_slli_si128(key, 4));
-  // unfortunately, shuffle expects immediate argument ... macrorize???!!!
-  switch (shuf) {
-    case 0x55:
-      keygened = _mm_shuffle_epi32(keygened, 0x55 );
-      break;
-    case 0xaa:
-      keygened = _mm_shuffle_epi32(keygened, 0xaa );
-      break;
-    case 0xff:
-      keygened = _mm_shuffle_epi32(keygened, 0xff );
-      break;
-    default:
-      break;
-  }
-  return _mm_xor_si128(key, keygened);
+
+    key = _mm_xor_si128(key, _mm_slli_si128(key, 4));
+    key = _mm_xor_si128(key, _mm_slli_si128(key, 4));
+    key = _mm_xor_si128(key, _mm_slli_si128(key, 4));
+
+    // unfortunately, shuffle expects immediate argument, thus the not-so-stylish switch ...
+    // REVISIT: either macrorize this whole function (and perhaps the following one) or
+    // use shuffle_epi8 (which would require SSSE3 instead of SSE2)
+    switch(shuf) {
+        case 0x55:
+            keygened = _mm_shuffle_epi32(keygened, 0x55 );
+            break;
+        case 0xaa:
+            keygened = _mm_shuffle_epi32(keygened, 0xaa );
+            break;
+        case 0xff:
+            keygened = _mm_shuffle_epi32(keygened, 0xff );
+            break;
+        default:
+            break;
+    }
+
+    return _mm_xor_si128(key, keygened);
 }
 
 
-static __m128i aes192_keyexpand_2(__m128i key, __m128i key2)
-{
+static __m128i aes192_keyexpand_2(__m128i key, __m128i key2) {
+
     key = _mm_shuffle_epi32(key, 0xff);
     key2 = _mm_xor_si128(key2, _mm_slli_si128(key2, 4));
+
     return _mm_xor_si128(key, key2);
 }
 
 
-#define KEYEXP128(K, I)      aes128_keyexpand(K, _mm_aeskeygenassist_si128(K, I), 0xff)
-#define KEYEXP192(K1, K2, I) aes128_keyexpand(K1, _mm_aeskeygenassist_si128(K2, I), 0x55)
+#define KEYEXP128(K, I)      aes128_keyexpand (K, _mm_aeskeygenassist_si128(K, I), 0xff)
+#define KEYEXP192(K1, K2, I) aes128_keyexpand (K1, _mm_aeskeygenassist_si128(K2, I), 0x55)
 #define KEYEXP192_2(K1, K2)  aes192_keyexpand_2(K1, K2)
-#define KEYEXP256(K1, K2, I) aes128_keyexpand(K1, _mm_aeskeygenassist_si128(K2, I), 0xff)
-#define KEYEXP256_2(K1, K2)  aes128_keyexpand(K1, _mm_aeskeygenassist_si128(K2, 0x00), 0xaa)
+#define KEYEXP256(K1, K2, I) aes128_keyexpand (K1, _mm_aeskeygenassist_si128(K2, I), 0xff)
+#define KEYEXP256_2(K1, K2)  aes128_keyexpand (K1, _mm_aeskeygenassist_si128(K2, 0x00), 0xaa)
 
 
 // key setup
 static int aes_internal_key_setup (aes_context_t *ctx, const uint8_t *key, int key_bits) {
 
-  // number of rounds
-  ctx->Nr = 6 + (key_bits / 32);
-
-  // encryption keys
-  switch (key_bits) {
-    case 128: {
-      ctx->rk_enc[0] = _mm_loadu_si128((const __m128i*)key);
-      ctx->rk_enc[1] = KEYEXP128(ctx->rk_enc[0], 0x01);
-      ctx->rk_enc[2] = KEYEXP128(ctx->rk_enc[1], 0x02);
-      ctx->rk_enc[3] = KEYEXP128(ctx->rk_enc[2], 0x04);
-      ctx->rk_enc[4] = KEYEXP128(ctx->rk_enc[3], 0x08);
-      ctx->rk_enc[5] = KEYEXP128(ctx->rk_enc[4], 0x10);
-      ctx->rk_enc[6] = KEYEXP128(ctx->rk_enc[5], 0x20);
-      ctx->rk_enc[7] = KEYEXP128(ctx->rk_enc[6], 0x40);
-      ctx->rk_enc[8] = KEYEXP128(ctx->rk_enc[7], 0x80);
-      ctx->rk_enc[9] = KEYEXP128(ctx->rk_enc[8], 0x1B);
-      ctx->rk_enc[10] = KEYEXP128(ctx->rk_enc[9], 0x36);
-      break;
-    }
-    case 192: {
-      __m128i temp[2];
-      ctx->rk_enc[0] = _mm_loadu_si128((const __m128i*) key);
-      ctx->rk_enc[1] = _mm_loadu_si128((const __m128i*) (key+16));
-      temp[0] = KEYEXP192(ctx->rk_enc[0], ctx->rk_enc[1], 0x01);
-      temp[1] = KEYEXP192_2(temp[0], ctx->rk_enc[1]);
-      ctx->rk_enc[1] = (__m128i)_mm_shuffle_pd((__m128d)ctx->rk_enc[1], (__m128d)temp[0], 0);
-      ctx->rk_enc[2] = (__m128i)_mm_shuffle_pd((__m128d)temp[0], (__m128d)temp[1], 1);
-      ctx->rk_enc[3] = KEYEXP192(temp[0], temp[1], 0x02);
-      ctx->rk_enc[4] = KEYEXP192_2(ctx->rk_enc[3], temp[1]);
-      temp[0] = KEYEXP192(ctx->rk_enc[3], ctx->rk_enc[4], 0x04);
-      temp[1] = KEYEXP192_2(temp[0], ctx->rk_enc[4]);
-      ctx->rk_enc[4] = (__m128i)_mm_shuffle_pd((__m128d)ctx->rk_enc[4], (__m128d)temp[0], 0);
-      ctx->rk_enc[5] = (__m128i)_mm_shuffle_pd((__m128d)temp[0], (__m128d)temp[1], 1);
-      ctx->rk_enc[6] = KEYEXP192(temp[0], temp[1], 0x08);
-      ctx->rk_enc[7] = KEYEXP192_2(ctx->rk_enc[6], temp[1]);
-      temp[0] = KEYEXP192(ctx->rk_enc[6], ctx->rk_enc[7], 0x10);
-      temp[1] = KEYEXP192_2(temp[0], ctx->rk_enc[7]);
-      ctx->rk_enc[7] = (__m128i)_mm_shuffle_pd((__m128d)ctx->rk_enc[7], (__m128d)temp[0], 0);
-      ctx->rk_enc[8] = (__m128i)_mm_shuffle_pd((__m128d)temp[0], (__m128d)temp[1], 1);
-      ctx->rk_enc[9] = KEYEXP192(temp[0], temp[1], 0x20);
-      ctx->rk_enc[10] = KEYEXP192_2(ctx->rk_enc[9], temp[1]);
-      temp[0] = KEYEXP192(ctx->rk_enc[9], ctx->rk_enc[10], 0x40);
-      temp[1] = KEYEXP192_2(temp[0], ctx->rk_enc[10]);
-      ctx->rk_enc[10] = (__m128i)_mm_shuffle_pd((__m128d)ctx->rk_enc[10], (__m128d) temp[0], 0);
-      ctx->rk_enc[11] = (__m128i)_mm_shuffle_pd((__m128d)temp[0],(__m128d) temp[1], 1);
-      ctx->rk_enc[12] = KEYEXP192(temp[0], temp[1], 0x80);
-      break;
-    }
-    case 256: {
-      ctx->rk_enc[0] = _mm_loadu_si128((const __m128i*) key);
-      ctx->rk_enc[1] = _mm_loadu_si128((const __m128i*) (key+16));
-      ctx->rk_enc[2] = KEYEXP256(ctx->rk_enc[0], ctx->rk_enc[1], 0x01);
-      ctx->rk_enc[3] = KEYEXP256_2(ctx->rk_enc[1], ctx->rk_enc[2]);
-      ctx->rk_enc[4] = KEYEXP256(ctx->rk_enc[2], ctx->rk_enc[3], 0x02);
-      ctx->rk_enc[5] = KEYEXP256_2(ctx->rk_enc[3], ctx->rk_enc[4]);
-      ctx->rk_enc[6] = KEYEXP256(ctx->rk_enc[4], ctx->rk_enc[5], 0x04);
-      ctx->rk_enc[7] = KEYEXP256_2(ctx->rk_enc[5], ctx->rk_enc[6]);
-      ctx->rk_enc[8] = KEYEXP256(ctx->rk_enc[6], ctx->rk_enc[7], 0x08);
-      ctx->rk_enc[9] = KEYEXP256_2(ctx->rk_enc[7], ctx->rk_enc[8]);
-      ctx->rk_enc[10] = KEYEXP256(ctx->rk_enc[8], ctx->rk_enc[9], 0x10);
-      ctx->rk_enc[11] = KEYEXP256_2(ctx->rk_enc[9], ctx->rk_enc[10]);
-      ctx->rk_enc[12] = KEYEXP256(ctx->rk_enc[10], ctx->rk_enc[11], 0x20);
-      ctx->rk_enc[13] = KEYEXP256_2(ctx->rk_enc[11], ctx->rk_enc[12]);
-      ctx->rk_enc[14] = KEYEXP256(ctx->rk_enc[12], ctx->rk_enc[13], 0x40);
-      break;
+    // number of rounds
+    ctx->Nr = 6 + (key_bits / 32);
+
+    // encryption keys
+    switch(key_bits) {
+        case 128: {
+            ctx->rk_enc[ 0] = _mm_loadu_si128((const __m128i*)key);
+            ctx->rk_enc[ 1] = KEYEXP128(ctx->rk_enc[0], 0x01);
+            ctx->rk_enc[ 2] = KEYEXP128(ctx->rk_enc[1], 0x02);
+            ctx->rk_enc[ 3] = KEYEXP128(ctx->rk_enc[2], 0x04);
+            ctx->rk_enc[ 4] = KEYEXP128(ctx->rk_enc[3], 0x08);
+            ctx->rk_enc[ 5] = KEYEXP128(ctx->rk_enc[4], 0x10);
+            ctx->rk_enc[ 6] = KEYEXP128(ctx->rk_enc[5], 0x20);
+            ctx->rk_enc[ 7] = KEYEXP128(ctx->rk_enc[6], 0x40);
+            ctx->rk_enc[ 8] = KEYEXP128(ctx->rk_enc[7], 0x80);
+            ctx->rk_enc[ 9] = KEYEXP128(ctx->rk_enc[8], 0x1B);
+            ctx->rk_enc[10] = KEYEXP128(ctx->rk_enc[9], 0x36);
+            break;
+        }
+        case 192: {
+            __m128i temp[2];
+            ctx->rk_enc[ 0] = _mm_loadu_si128((const __m128i*) key);
+
+            ctx->rk_enc[ 1] = _mm_loadu_si128((const __m128i*) (key+16));
+            temp[0] = KEYEXP192(ctx->rk_enc[0], ctx->rk_enc[1], 0x01);
+            temp[1] = KEYEXP192_2(temp[0], ctx->rk_enc[1]);
+            ctx->rk_enc[ 1] = (__m128i)_mm_shuffle_pd((__m128d)ctx->rk_enc[1], (__m128d)temp[0], 0);
+
+            ctx->rk_enc[ 2] = (__m128i)_mm_shuffle_pd((__m128d)temp[0], (__m128d)temp[1], 1);
+            ctx->rk_enc[ 3] = KEYEXP192(temp[0], temp[1], 0x02);
+
+            ctx->rk_enc[ 4] = KEYEXP192_2(ctx->rk_enc[3], temp[1]);
+            temp[0] = KEYEXP192(ctx->rk_enc[3], ctx->rk_enc[4], 0x04);
+            temp[1] = KEYEXP192_2(temp[0], ctx->rk_enc[4]);
+            ctx->rk_enc[ 4] = (__m128i)_mm_shuffle_pd((__m128d)ctx->rk_enc[4], (__m128d)temp[0], 0);
+
+            ctx->rk_enc[ 5] = (__m128i)_mm_shuffle_pd((__m128d)temp[0], (__m128d)temp[1], 1);
+            ctx->rk_enc[ 6] = KEYEXP192(temp[0], temp[1], 0x08);
+
+            ctx->rk_enc[ 7] = KEYEXP192_2(ctx->rk_enc[6], temp[1]);
+            temp[0] = KEYEXP192(ctx->rk_enc[6], ctx->rk_enc[7], 0x10);
+            temp[1] = KEYEXP192_2(temp[0], ctx->rk_enc[7]);
+            ctx->rk_enc[ 7] = (__m128i)_mm_shuffle_pd((__m128d)ctx->rk_enc[7], (__m128d)temp[0], 0);
+
+            ctx->rk_enc[ 8] = (__m128i)_mm_shuffle_pd((__m128d)temp[0], (__m128d)temp[1], 1);
+            ctx->rk_enc[ 9] = KEYEXP192(temp[0], temp[1], 0x20);
+
+            ctx->rk_enc[10] = KEYEXP192_2(ctx->rk_enc[9], temp[1]);
+            temp[0] = KEYEXP192(ctx->rk_enc[9], ctx->rk_enc[10], 0x40);
+            temp[1] = KEYEXP192_2(temp[0], ctx->rk_enc[10]);
+            ctx->rk_enc[10] = (__m128i)_mm_shuffle_pd((__m128d)ctx->rk_enc[10], (__m128d) temp[0], 0);
+
+            ctx->rk_enc[11] = (__m128i)_mm_shuffle_pd((__m128d)temp[0],(__m128d) temp[1], 1);
+            ctx->rk_enc[12] = KEYEXP192(temp[0], temp[1], 0x80);
+            break;
+        }
+        case 256: {
+            ctx->rk_enc[ 0] = _mm_loadu_si128((const __m128i*) key);
+            ctx->rk_enc[ 1] = _mm_loadu_si128((const __m128i*) (key+16));
+            ctx->rk_enc[ 2] = KEYEXP256(ctx->rk_enc[0], ctx->rk_enc[1], 0x01);
+            ctx->rk_enc[ 3] = KEYEXP256_2(ctx->rk_enc[1], ctx->rk_enc[2]);
+            ctx->rk_enc[ 4] = KEYEXP256(ctx->rk_enc[2], ctx->rk_enc[3], 0x02);
+            ctx->rk_enc[ 5] = KEYEXP256_2(ctx->rk_enc[3], ctx->rk_enc[4]);
+            ctx->rk_enc[ 6] = KEYEXP256(ctx->rk_enc[4], ctx->rk_enc[5], 0x04);
+            ctx->rk_enc[ 7] = KEYEXP256_2(ctx->rk_enc[5], ctx->rk_enc[6]);
+            ctx->rk_enc[ 8] = KEYEXP256(ctx->rk_enc[6], ctx->rk_enc[7], 0x08);
+            ctx->rk_enc[ 9] = KEYEXP256_2(ctx->rk_enc[7], ctx->rk_enc[8]);
+            ctx->rk_enc[10] = KEYEXP256(ctx->rk_enc[8], ctx->rk_enc[9], 0x10);
+            ctx->rk_enc[11] = KEYEXP256_2(ctx->rk_enc[9], ctx->rk_enc[10]);
+            ctx->rk_enc[12] = KEYEXP256(ctx->rk_enc[10], ctx->rk_enc[11], 0x20);
+            ctx->rk_enc[13] = KEYEXP256_2(ctx->rk_enc[11], ctx->rk_enc[12]);
+            ctx->rk_enc[14] = KEYEXP256(ctx->rk_enc[12], ctx->rk_enc[13], 0x40);
+            break;
+        }
     }
-  }
-  // derive decryption keys
-  for (int i = 1; i < ctx->Nr; ++i) {
-    ctx->rk_dec[ctx->Nr - i] = _mm_aesimc_si128(ctx->rk_enc[i]);
-  }
-  ctx->rk_dec[0] = ctx->rk_enc[ctx->Nr];
+    // derive decryption keys
+    for(int i = 1; i < ctx->Nr; ++i) {
+        ctx->rk_dec[ctx->Nr - i] = _mm_aesimc_si128(ctx->rk_enc[i]);
+    }
+    ctx->rk_dec[ 0] = ctx->rk_enc[ctx->Nr];
 
-  return ctx->Nr;
+    return ctx->Nr;
 }
 
 
 static void aes_internal_encrypt (aes_context_t *ctx, const uint8_t pt[16], uint8_t ct[16]) {
 
-  __m128i tmp = _mm_loadu_si128((__m128i*)pt);
-
-  tmp = _mm_xor_si128 (tmp, ctx->rk_enc[ 0]);
-  tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 1]);
-  tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 2]);
-  tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 3]);
-  tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 4]);
-  tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 5]);
-  tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 6]);
-  tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 7]);
-  tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 8]);
-  tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 9]);
-  if(ctx->Nr > 10) {
-    tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[10]);
-    tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[11]);
-    if(ctx->Nr > 12) {
-      tmp = _mm_aesenc_si128(tmp, ctx->rk_enc[12]);
-      tmp = _mm_aesenc_si128(tmp, ctx->rk_enc[13]);
+    __m128i tmp = _mm_loadu_si128((__m128i*)pt);
+
+    tmp = _mm_xor_si128        (tmp, ctx->rk_enc[ 0]);
+    tmp = _mm_aesenc_si128     (tmp, ctx->rk_enc[ 1]);
+    tmp = _mm_aesenc_si128     (tmp, ctx->rk_enc[ 2]);
+    tmp = _mm_aesenc_si128     (tmp, ctx->rk_enc[ 3]);
+    tmp = _mm_aesenc_si128     (tmp, ctx->rk_enc[ 4]);
+    tmp = _mm_aesenc_si128     (tmp, ctx->rk_enc[ 5]);
+    tmp = _mm_aesenc_si128     (tmp, ctx->rk_enc[ 6]);
+    tmp = _mm_aesenc_si128     (tmp, ctx->rk_enc[ 7]);
+    tmp = _mm_aesenc_si128     (tmp, ctx->rk_enc[ 8]);
+    tmp = _mm_aesenc_si128     (tmp, ctx->rk_enc[ 9]);
+    if(ctx->Nr > 10) {
+        tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[10]);
+        tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[11]);
+        if(ctx->Nr > 12) {
+            tmp = _mm_aesenc_si128(tmp, ctx->rk_enc[12]);
+            tmp = _mm_aesenc_si128(tmp, ctx->rk_enc[13]);
+        }
     }
-  }
-  tmp = _mm_aesenclast_si128(tmp, ctx->rk_enc[ctx->Nr]);
+    tmp = _mm_aesenclast_si128 (tmp, ctx->rk_enc[ctx->Nr]);
 
-  _mm_storeu_si128((__m128i*) ct, tmp);
+    _mm_storeu_si128((__m128i*) ct, tmp);
 }
 
 
 static void aes_internal_decrypt (aes_context_t *ctx, const uint8_t ct[16], uint8_t pt[16]) {
 
-  __m128i tmp = _mm_loadu_si128((__m128i*)ct);
-
-  tmp = _mm_xor_si128 (tmp, ctx->rk_dec[ 0]);
-  tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 1]);
-  tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 2]);
-  tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 3]);
-  tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 4]);
-  tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 5]);
-  tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 6]);
-  tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 7]);
-  tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 8]);
-  tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 9]);
-  if(ctx->Nr > 10) {
-    tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[10]);
-    tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[11]);
-    if(ctx->Nr > 12) {
-      tmp = _mm_aesdec_si128(tmp, ctx->rk_dec[12]);
-      tmp = _mm_aesdec_si128(tmp, ctx->rk_dec[13]);
+    __m128i tmp = _mm_loadu_si128((__m128i*)ct);
+
+    tmp = _mm_xor_si128        (tmp, ctx->rk_dec[ 0]);
+    tmp = _mm_aesdec_si128     (tmp, ctx->rk_dec[ 1]);
+    tmp = _mm_aesdec_si128     (tmp, ctx->rk_dec[ 2]);
+    tmp = _mm_aesdec_si128     (tmp, ctx->rk_dec[ 3]);
+    tmp = _mm_aesdec_si128     (tmp, ctx->rk_dec[ 4]);
+    tmp = _mm_aesdec_si128     (tmp, ctx->rk_dec[ 5]);
+    tmp = _mm_aesdec_si128     (tmp, ctx->rk_dec[ 6]);
+    tmp = _mm_aesdec_si128     (tmp, ctx->rk_dec[ 7]);
+    tmp = _mm_aesdec_si128     (tmp, ctx->rk_dec[ 8]);
+    tmp = _mm_aesdec_si128     (tmp, ctx->rk_dec[ 9]);
+    if(ctx->Nr > 10) {
+        tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[10]);
+        tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[11]);
+        if(ctx->Nr > 12) {
+            tmp = _mm_aesdec_si128(tmp, ctx->rk_dec[12]);
+            tmp = _mm_aesdec_si128(tmp, ctx->rk_dec[13]);
+        }
     }
-  }
-  tmp = _mm_aesdeclast_si128(tmp, ctx->rk_enc[0]);
+    tmp = _mm_aesdeclast_si128 (tmp, ctx->rk_enc[ 0]);
 
-  _mm_storeu_si128((__m128i*) pt, tmp);
+    _mm_storeu_si128((__m128i*) pt, tmp);
 }
 
 
@@ -350,243 +367,249 @@ static void aes_internal_decrypt (aes_context_t *ctx, const uint8_t ct[16], uint
 
 int aes_ecb_decrypt (unsigned char *out, const unsigned char *in, aes_context_t *ctx) {
 
-  aes_internal_decrypt(ctx, in, out);
+    aes_internal_decrypt(ctx, in, out);
 
-  return AES_BLOCK_SIZE;
+    return AES_BLOCK_SIZE;
 }
 
 
 // not used
 int aes_ecb_encrypt (unsigned char *out, const unsigned char *in, aes_context_t *ctx) {
 
-  aes_internal_encrypt(ctx, in, out);
+    aes_internal_encrypt(ctx, in, out);
 
-  return AES_BLOCK_SIZE;
+    return AES_BLOCK_SIZE;
 }
 
 
 int aes_cbc_encrypt (unsigned char *out, const unsigned char *in, size_t in_len,
                      const unsigned char *iv, aes_context_t *ctx) {
 
-  int n; // number of blocks
-  int ret = (int)in_len & 15; // remainder
-
-  __m128i ivec = _mm_loadu_si128((__m128i*)iv);
-
-  for(n = in_len / 16; n != 0; n--) {
-    __m128i tmp = _mm_loadu_si128((__m128i*)in);
-    in += 16;
-    tmp = _mm_xor_si128(tmp, ivec);
-
-    tmp = _mm_xor_si128 (tmp, ctx->rk_enc[ 0]);
-    tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 1]);
-    tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 2]);
-    tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 3]);
-    tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 4]);
-    tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 5]);
-    tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 6]);
-    tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 7]);
-    tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 8]);
-    tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 9]);
-    if(ctx->Nr > 10) {
-      tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[10]);
-      tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[11]);
-      if(ctx->Nr > 12) {
-        tmp = _mm_aesenc_si128(tmp, ctx->rk_enc[12]);
-        tmp = _mm_aesenc_si128(tmp, ctx->rk_enc[13]);
-      }
-    }
-    tmp = _mm_aesenclast_si128(tmp, ctx->rk_enc[ctx->Nr]);
+    int n;                        /* number of blocks */
+    int ret = (int)in_len & 15;   /* remainder */
+
+    __m128i ivec = _mm_loadu_si128((__m128i*)iv);
+
+    for(n = in_len / 16; n != 0; n--) {
+        __m128i tmp = _mm_loadu_si128((__m128i*)in);
+        in += 16;
+        tmp = _mm_xor_si128(tmp, ivec);
+
+        tmp = _mm_xor_si128        (tmp, ctx->rk_enc[ 0]);
+        tmp = _mm_aesenc_si128     (tmp, ctx->rk_enc[ 1]);
+        tmp = _mm_aesenc_si128     (tmp, ctx->rk_enc[ 2]);
+        tmp = _mm_aesenc_si128     (tmp, ctx->rk_enc[ 3]);
+        tmp = _mm_aesenc_si128     (tmp, ctx->rk_enc[ 4]);
+        tmp = _mm_aesenc_si128     (tmp, ctx->rk_enc[ 5]);
+        tmp = _mm_aesenc_si128     (tmp, ctx->rk_enc[ 6]);
+        tmp = _mm_aesenc_si128     (tmp, ctx->rk_enc[ 7]);
+        tmp = _mm_aesenc_si128     (tmp, ctx->rk_enc[ 8]);
+        tmp = _mm_aesenc_si128     (tmp, ctx->rk_enc[ 9]);
+        if(ctx->Nr > 10) {
+            tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[10]);
+            tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[11]);
+            if(ctx->Nr > 12) {
+                tmp = _mm_aesenc_si128(tmp, ctx->rk_enc[12]);
+                tmp = _mm_aesenc_si128(tmp, ctx->rk_enc[13]);
+            }
        }
+        tmp = _mm_aesenclast_si128 (tmp, ctx->rk_enc[ctx->Nr]);
 
-    ivec = tmp;
+        ivec = tmp;
 
-    _mm_storeu_si128((__m128i*)out, tmp);
-    out += 16;
-  }
+        _mm_storeu_si128((__m128i*)out, tmp);
+        out += 16;
+    }
 
-  return ret;
+    return ret;
 }
 
 
 int aes_cbc_decrypt (unsigned char *out, const unsigned char *in, size_t in_len,
                      const unsigned char *iv, aes_context_t *ctx) {
 
-  int n; // number of blocks
-  int ret = (int)in_len & 15; // remainder
+    int n;                        /* number of blocks */
+    int ret = (int)in_len & 15;   /* remainder */
 
-  __m128i ivec = _mm_loadu_si128((__m128i*)iv);
+    __m128i ivec = _mm_loadu_si128((__m128i*)iv);
 
-  for(n = in_len / 16; n > 3; n -=4) {
-    __m128i tmp1 = _mm_loadu_si128((__m128i*)in); in += 16;
-    __m128i tmp2 = _mm_loadu_si128((__m128i*)in); in += 16;
-    __m128i tmp3 = _mm_loadu_si128((__m128i*)in); in += 16;
-    __m128i tmp4 = _mm_loadu_si128((__m128i*)in); in += 16;
+    // 4 parallel rails of AES decryption to reduce data dependencies in x86's deep pipelines
+    for(n = in_len / 16; n > 3; n -=4) {
+        __m128i tmp1 = _mm_loadu_si128((__m128i*)in); in += 16;
+        __m128i tmp2 = _mm_loadu_si128((__m128i*)in); in += 16;
+        __m128i tmp3 = _mm_loadu_si128((__m128i*)in); in += 16;
+        __m128i tmp4 = _mm_loadu_si128((__m128i*)in); in += 16;
 
-    __m128i old_in1 = tmp1;
-    __m128i old_in2 = tmp2;
-    __m128i old_in3 = tmp3;
-    __m128i old_in4 = tmp4;
+        __m128i old_in1 = tmp1;
+        __m128i old_in2 = tmp2;
+        __m128i old_in3 = tmp3;
+        __m128i old_in4 = tmp4;
 
-    tmp1 = _mm_xor_si128 (tmp1, ctx->rk_dec[ 0]); tmp2 = _mm_xor_si128 (tmp2, ctx->rk_dec[ 0]);
-    tmp3 = _mm_xor_si128 (tmp3, ctx->rk_dec[ 0]); tmp4 = _mm_xor_si128 (tmp4, ctx->rk_dec[ 0]);
+        tmp1 = _mm_xor_si128 (tmp1, ctx->rk_dec[ 0]); tmp2 = _mm_xor_si128 (tmp2, ctx->rk_dec[ 0]);
+        tmp3 = _mm_xor_si128 (tmp3, ctx->rk_dec[ 0]); tmp4 = _mm_xor_si128 (tmp4, ctx->rk_dec[ 0]);
 
-    tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 1]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 1]);
-    tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 1]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 1]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 1]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 1]);
+        tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 1]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 1]);
 
-    tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 2]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 2]);
-    tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 2]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 2]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 2]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 2]);
+        tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 2]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 2]);
 
-    tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 3]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 3]);
-    tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 3]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 3]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 3]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 3]);
+        tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 3]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 3]);
 
-    tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 4]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 4]);
-    tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 4]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 4]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 4]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 4]);
+        tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 4]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 4]);
 
-    tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 5]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 5]);
-    tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 5]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 5]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 5]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 5]);
+        tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 5]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 5]);
 
-    tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 6]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 6]);
-    tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 6]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 6]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 6]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 6]);
+        tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 6]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 6]);
 
-    tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 7]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 7]);
-    tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 7]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 7]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 7]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 7]);
+        tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 7]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 7]);
 
-    tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 8]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 8]);
-    tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 8]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 8]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 8]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 8]);
+        tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 8]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 8]);
 
-    tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 9]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 9]);
-    tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 9]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 9]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 9]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 9]);
+        tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[ 9]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[ 9]);
 
-    if(ctx->Nr > 10) {
-      tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[10]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[10]);
-      tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[10]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[10]);
+        if(ctx->Nr > 10) {
+            tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[10]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[10]);
+            tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[10]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[10]);
 
-      tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[11]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[11]);
-      tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[11]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[11]);
+            tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[11]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[11]);
+            tmp3 = _mm_aesdec_si128 (tmp3, ctx->rk_dec[11]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[11]);
 
-      if(ctx->Nr > 12) {
-        tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[12]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[12]);
-        tmp3 = _mm_aesdec_si128(tmp3, ctx->rk_dec[12]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[12]);
+            if(ctx->Nr > 12) {
+                tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[12]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[12]);
+                tmp3 = _mm_aesdec_si128(tmp3, ctx->rk_dec[12]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[12]);
 
-        tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[13]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[13]);
-        tmp3 = _mm_aesdec_si128(tmp3, ctx->rk_dec[13]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[13]);
-      }
-    }
-    tmp1 = _mm_aesdeclast_si128(tmp1, ctx->rk_enc[ 0]); tmp2 = _mm_aesdeclast_si128(tmp2, ctx->rk_enc[ 0]);
-    tmp3 = _mm_aesdeclast_si128(tmp3, ctx->rk_enc[ 0]); tmp4 = _mm_aesdeclast_si128(tmp4, ctx->rk_enc[ 0]);
-
-    tmp1 = _mm_xor_si128 (tmp1, ivec); tmp2 = _mm_xor_si128 (tmp2, old_in1);
-    tmp3 = _mm_xor_si128 (tmp3, old_in2); tmp4 = _mm_xor_si128 (tmp4, old_in3);
-
-    ivec = old_in4;
-
-    _mm_storeu_si128((__m128i*) out, tmp1); out += 16;
-    _mm_storeu_si128((__m128i*) out, tmp2); out += 16;
-    _mm_storeu_si128((__m128i*) out, tmp3); out += 16;
-    _mm_storeu_si128((__m128i*) out, tmp4); out += 16;
-  } // now: less than 4 blocks remaining
-
-  if(n > 1) { // 2 or 3 blocks remaining --> this code handles two of them
-    n-= 2;
-
-    __m128i tmp1 = _mm_loadu_si128((__m128i*)in); in += 16;
-    __m128i tmp2 = _mm_loadu_si128((__m128i*)in); in += 16;
-
-    __m128i old_in1 = tmp1;
-    __m128i old_in2 = tmp2;
-
-    tmp1 = _mm_xor_si128 (tmp1, ctx->rk_dec[ 0]); tmp2 = _mm_xor_si128 (tmp2, ctx->rk_dec[ 0]);
-    tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 1]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 1]);
-    tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 2]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 2]);
-    tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 3]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 3]);
-    tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 4]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 4]);
-    tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 5]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 5]);
-    tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 6]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 6]);
-    tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 7]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 7]);
-    tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 8]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 8]);
-    tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 9]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 9]);
-    if(ctx->Nr > 10) {
-      tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[10]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[10]);
-      tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[11]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[11]);
-      if(ctx->Nr > 12) {
-        tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[12]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[12]);
-        tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[13]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[13]);
-      }
+                tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[13]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[13]);
+                tmp3 = _mm_aesdec_si128(tmp3, ctx->rk_dec[13]); tmp4 = _mm_aesdec_si128 (tmp4, ctx->rk_dec[13]);
+            }
+        }
+        tmp1 = _mm_aesdeclast_si128(tmp1, ctx->rk_enc[ 0]); tmp2 = _mm_aesdeclast_si128(tmp2, ctx->rk_enc[ 0]);
+        tmp3 = _mm_aesdeclast_si128(tmp3, ctx->rk_enc[ 0]); tmp4 = _mm_aesdeclast_si128(tmp4, ctx->rk_enc[ 0]);
+
+        tmp1 = _mm_xor_si128 (tmp1, ivec);    tmp2 = _mm_xor_si128 (tmp2, old_in1);
+        tmp3 = _mm_xor_si128 (tmp3, old_in2); tmp4 = _mm_xor_si128 (tmp4, old_in3);
+
+        ivec = old_in4;
+
+        _mm_storeu_si128((__m128i*) out, tmp1); out += 16;
+        _mm_storeu_si128((__m128i*) out, tmp2); out += 16;
+        _mm_storeu_si128((__m128i*) out, tmp3); out += 16;
+        _mm_storeu_si128((__m128i*) out, tmp4); out += 16;
     }
-    tmp1 = _mm_aesdeclast_si128(tmp1, ctx->rk_enc[ 0]); tmp2 = _mm_aesdeclast_si128(tmp2, ctx->rk_enc[ 0]);
-
-    tmp1 = _mm_xor_si128 (tmp1, ivec); tmp2 = _mm_xor_si128 (tmp2, old_in1);
-
-    ivec = old_in2;
-
-    _mm_storeu_si128((__m128i*) out, tmp1); out += 16;
-    _mm_storeu_si128((__m128i*) out, tmp2); out += 16;
-  }
-
-  if(n) { // one block remaining
-    __m128i tmp = _mm_loadu_si128((__m128i*)in);
-    __m128i old_in = tmp;
-
-    tmp = _mm_xor_si128 (tmp, ctx->rk_dec[ 0]);
-    tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 1]);
-    tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 2]);
-    tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 3]);
-    tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 4]);
-    tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 5]);
-    tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 6]);
-    tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 7]);
-    tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 8]);
-    tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 9]);
-    if(ctx->Nr > 10) {
-      tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[10]);
-      tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[11]);
-      if(ctx->Nr > 12) {
-        tmp = _mm_aesdec_si128(tmp, ctx->rk_dec[12]);
-        tmp = _mm_aesdec_si128(tmp, ctx->rk_dec[13]);
-      }
+    // now: less than 4 blocks remaining
+
+    // if 2 or 3 blocks remaining --> this code handles two of them
+    if(n > 1) {
+        n-= 2;
+
+        __m128i tmp1 = _mm_loadu_si128((__m128i*)in); in += 16;
+        __m128i tmp2 = _mm_loadu_si128((__m128i*)in); in += 16;
+
+        __m128i old_in1 = tmp1;
+        __m128i old_in2 = tmp2;
+
+        tmp1 = _mm_xor_si128 (tmp1, ctx->rk_dec[ 0]);    tmp2 = _mm_xor_si128 (tmp2, ctx->rk_dec[ 0]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 1]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 1]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 2]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 2]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 3]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 3]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 4]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 4]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 5]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 5]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 6]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 6]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 7]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 7]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 8]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 8]);
+        tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[ 9]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[ 9]);
+        if(ctx->Nr > 10) {
+            tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[10]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[10]);
+            tmp1 = _mm_aesdec_si128 (tmp1, ctx->rk_dec[11]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[11]);
+            if(ctx->Nr > 12) {
+                tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[12]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[12]);
+                tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[13]); tmp2 = _mm_aesdec_si128 (tmp2, ctx->rk_dec[13]);
+            }
+        }
+        tmp1 = _mm_aesdeclast_si128 (tmp1, ctx->rk_enc[ 0]); tmp2 = _mm_aesdeclast_si128(tmp2, ctx->rk_enc[ 0]);
+
+        tmp1 = _mm_xor_si128 (tmp1, ivec); tmp2 = _mm_xor_si128 (tmp2, old_in1);
+
+        ivec = old_in2;
+
+        _mm_storeu_si128((__m128i*) out, tmp1); out += 16;
+        _mm_storeu_si128((__m128i*) out, tmp2); out += 16;
     }
-    tmp = _mm_aesdeclast_si128(tmp, ctx->rk_enc[ 0]);
-    tmp = _mm_xor_si128 (tmp, ivec);
 
-    _mm_storeu_si128((__m128i*) out, tmp);
-  }
+    // one block remaining
+    if(n) {
+        __m128i tmp = _mm_loadu_si128((__m128i*)in);
+        __m128i old_in = tmp;
+
+        tmp = _mm_xor_si128        (tmp, ctx->rk_dec[ 0]);
+        tmp = _mm_aesdec_si128     (tmp, ctx->rk_dec[ 1]);
+        tmp = _mm_aesdec_si128     (tmp, ctx->rk_dec[ 2]);
+        tmp = _mm_aesdec_si128     (tmp, ctx->rk_dec[ 3]);
+        tmp = _mm_aesdec_si128     (tmp, ctx->rk_dec[ 4]);
+        tmp = _mm_aesdec_si128     (tmp, ctx->rk_dec[ 5]);
+        tmp = _mm_aesdec_si128     (tmp, ctx->rk_dec[ 6]);
+        tmp = _mm_aesdec_si128     (tmp, ctx->rk_dec[ 7]);
+        tmp = _mm_aesdec_si128     (tmp, ctx->rk_dec[ 8]);
+        tmp = _mm_aesdec_si128     (tmp, ctx->rk_dec[ 9]);
+        if(ctx->Nr > 10) {
+            tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[10]);
+            tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[11]);
+            if(ctx->Nr > 12) {
+                tmp = _mm_aesdec_si128(tmp, ctx->rk_dec[12]);
+                tmp = _mm_aesdec_si128(tmp, ctx->rk_dec[13]);
+            }
+        }
+        tmp = _mm_aesdeclast_si128 (tmp, ctx->rk_enc[ 0]);
 
-  return ret;
+        tmp = _mm_xor_si128 (tmp, ivec);
+
+        _mm_storeu_si128((__m128i*) out, tmp);
+    }
+
+    return ret;
 }
 
 
 int aes_init (const unsigned char *key, size_t key_size, aes_context_t **ctx) {
 
-  // allocate context...
-  *ctx = (aes_context_t*) calloc(1, sizeof(aes_context_t));
-  if (!(*ctx))
-    return -1;
-  // ...and fill her up
-
-  // initialize data structures
-
-  // check key size and make key size (given in bytes) dependant settings
-  switch(key_size) {
-    case AES128_KEY_BYTES: // 128 bit key size
-      break;
-    case AES192_KEY_BYTES: // 192 bit key size
-      break;
-    case AES256_KEY_BYTES: // 256 bit key size
-      break;
-    default:
-      traceEvent(TRACE_ERROR, "aes_init invalid key size %u\n", key_size);
-      return -1;
-  }
-
-  // key materiel handling
-  aes_internal_key_setup ( *ctx, key, 8 * key_size);
-  return 0;
+    // allocate context...
+    *ctx = (aes_context_t*) calloc(1, sizeof(aes_context_t));
+    if(!(*ctx))
+        return -1;
+    // ...and fill her up:
+
+    // initialize data structures
+
+    // check key size and make key size (given in bytes) dependent settings
+    switch(key_size) {
+        case AES128_KEY_BYTES: // 128 bit key size
+            break;
+        case AES192_KEY_BYTES: // 192 bit key size
+            break;
+        case AES256_KEY_BYTES: // 256 bit key size
+            break;
+        default:
+            traceEvent(TRACE_ERROR, "aes_init invalid key size %u\n", key_size);
+            return -1;
+    }
+
+    // key material handling
+    aes_internal_key_setup ( *ctx, key, 8 * key_size);
+
+    return 0;
 }
 
 
-#else // plain C --------------------------------------------------------------------------
+#else // plain C ---------------------------------------------------------------------------------------------------
+
 
 // rijndael-alg-fst.c version 3.0 (December 2000)
 // optimised ANSI C code for the Rijndael cipher (now AES)
 // original authors: Vincent Rijmen
 
@@ -949,9 +972,9 @@ static const uint32_t Td4[256] = {
 
 // for 128-bit blocks, Rijndael never uses more than 10 rcon values
 static const uint32_t rcon[] = {
-  0x01000000, 0x02000000, 0x04000000, 0x08000000,
-  0x10000000, 0x20000000, 0x40000000, 0x80000000,
-  0x1B000000, 0x36000000 };
+    0x01000000, 0x02000000, 0x04000000, 0x08000000,
+    0x10000000, 0x20000000, 0x40000000, 0x80000000,
+    0x1B000000, 0x36000000 };
 
 #define GETU32(p) (be32toh((*((uint32_t*)(p)))))
 
@@ -968,224 +991,219 @@ static const uint32_t rcon[] = {
 
 #define m3(x) ((x) & 0xff000000)
 
-/**
- * Expand the cipher key into the encryption key schedule.
- *
- * @return the number of rounds for the given cipher key size.
- */
+// expand the cipher key into the encryption key schedule and
+// return the number of rounds for the given cipher key size
 static int aes_internal_key_setup_enc (uint32_t rk[/*4*(Nr + 1)*/], const uint8_t cipherKey[], int keyBits) {
 
-  int i = 0;
-  uint32_t temp;
-
-  rk[0] = GETU32(cipherKey     );
-  rk[1] = GETU32(cipherKey +  4);
-  rk[2] = GETU32(cipherKey +  8);
-  rk[3] = GETU32(cipherKey + 12);
-  if (keyBits == 128) {
-    for (;;) {
-      temp  = rk[3];
-      rk[4] = rk[0] ^
-              (Te4[b2(temp)] & 0xff000000) ^
-              (Te4[b1(temp)] & 0x00ff0000) ^
-              (Te4[b0(temp)] & 0x0000ff00) ^
-              (Te4[b3(temp)] & 0x000000ff) ^
-              rcon[i];
-      rk[5] = rk[1] ^ rk[4];
-      rk[6] = rk[2] ^ rk[5];
-      rk[7] = rk[3] ^ rk[6];
-      if (++i == 10) {
-        return 10;
-      }
-      rk += 4;
-    }
-  }
-  rk[4] = GETU32(cipherKey + 16);
-  rk[5] = GETU32(cipherKey + 20);
-  if (keyBits == 192) {
-    for (;;) {
-      temp = rk[ 5];
-      rk[ 6] = rk[ 0] ^
-               (Te4[b2(temp)] & 0xff000000) ^
-               (Te4[b1(temp)] & 0x00ff0000) ^
-               (Te4[b0(temp)] & 0x0000ff00) ^
-               (Te4[b3(temp)] & 0x000000ff) ^
-               rcon[i];
-      rk[ 7] = rk[ 1] ^ rk[ 6];
-      rk[ 8] = rk[ 2] ^ rk[ 7];
-      rk[ 9] = rk[ 3] ^ rk[ 8];
-      if (++i == 8) {
-        return 12;
-      }
-      rk[10] = rk[ 4] ^ rk[ 9];
-      rk[11] = rk[ 5] ^ rk[10];
-      rk += 6;
-    }
-  }
-  rk[6] = GETU32(cipherKey + 24);
-  rk[7] = GETU32(cipherKey + 28);
-  if (keyBits == 256) {
-    for (;;) {
-      temp = rk[ 7];
-      rk[ 8] = rk[ 0] ^
-               (Te4[b2(temp)] & 0xff000000) ^
-               (Te4[b1(temp)] & 0x00ff0000) ^
-               (Te4[b0(temp)] & 0x0000ff00) ^
-               (Te4[b3(temp)] & 0x000000ff) ^
-               rcon[i];
-      rk[ 9] = rk[ 1] ^ rk[ 8];
-      rk[10] = rk[ 2] ^ rk[ 9];
-      rk[11] = rk[ 3] ^ rk[10];
-      if (++i == 7) {
-        return 14;
-      }
-      temp = rk[11];
-      rk[12] = rk[ 4] ^
-               (Te4[b3(temp)] & 0xff000000) ^
-               (Te4[b2(temp)] & 0x00ff0000) ^
-               (Te4[b1(temp)] & 0x0000ff00) ^
-               (Te4[b0(temp)] & 0x000000ff);
-      rk[13] = rk[ 5] ^ rk[12];
-      rk[14] = rk[ 6] ^ rk[13];
-      rk[15] = rk[ 7] ^ rk[14];
-
-      rk += 8;
+    int i = 0;
+    uint32_t temp;
+
+    rk[0] = GETU32(cipherKey     );
+    rk[1] = GETU32(cipherKey +  4);
+    rk[2] = GETU32(cipherKey +  8);
+    rk[3] = GETU32(cipherKey + 12);
+    if(keyBits == 128) {
+        for(;;) {
+            temp  = rk[3];
+            rk[4] = rk[0] ^
+                    (Te4[b2(temp)] & 0xff000000) ^
+                    (Te4[b1(temp)] & 0x00ff0000) ^
+                    (Te4[b0(temp)] & 0x0000ff00) ^
+                    (Te4[b3(temp)] & 0x000000ff) ^
+                    rcon[i];
+            rk[5] = rk[1] ^ rk[4];
+            rk[6] = rk[2] ^ rk[5];
+            rk[7] = rk[3] ^ rk[6];
+            if(++i == 10) {
+                return 10;
+            }
+            rk += 4;
+        }
+    }
+    rk[4] = GETU32(cipherKey + 16);
+    rk[5] = GETU32(cipherKey + 20);
+    if(keyBits == 192) {
+        for(;;) {
+            temp = rk[ 5];
+            rk[ 6] = rk[ 0] ^
+                     (Te4[b2(temp)] & 0xff000000) ^
+                     (Te4[b1(temp)] & 0x00ff0000) ^
+                     (Te4[b0(temp)] & 0x0000ff00) ^
+                     (Te4[b3(temp)] & 0x000000ff) ^
+                     rcon[i];
+            rk[ 7] = rk[ 1] ^ rk[ 6];
+            rk[ 8] = rk[ 2] ^ rk[ 7];
+            rk[ 9] = rk[ 3] ^ rk[ 8];
+            if(++i == 8) {
+                return 12;
+            }
+            rk[10] = rk[ 4] ^ rk[ 9];
+            rk[11] = rk[ 5] ^ rk[10];
+            rk += 6;
+        }
+    }
+    rk[6] = GETU32(cipherKey + 24);
+    rk[7] = GETU32(cipherKey + 28);
+    if(keyBits == 256) {
+        for(;;) {
+            temp = rk[ 7];
+            rk[ 8] = rk[ 0] ^
+                     (Te4[b2(temp)] & 0xff000000) ^
+                     (Te4[b1(temp)] & 0x00ff0000) ^
+                     (Te4[b0(temp)] & 0x0000ff00) ^
+                     (Te4[b3(temp)] & 0x000000ff) ^
+                     rcon[i];
+            rk[ 9] = rk[ 1] ^ rk[ 8];
+            rk[10] = rk[ 2] ^ rk[ 9];
+            rk[11] = rk[ 3] ^ rk[10];
+            if(++i == 7) {
+                return 14;
+            }
+            temp = rk[11];
+            rk[12] = rk[ 4] ^
+                     (Te4[b3(temp)] & 0xff000000) ^
+                     (Te4[b2(temp)] & 0x00ff0000) ^
+                     (Te4[b1(temp)] & 0x0000ff00) ^
+                     (Te4[b0(temp)] & 0x000000ff);
+            rk[13] = rk[ 5] ^ rk[12];
+            rk[14] = rk[ 6] ^ rk[13];
+            rk[15] = rk[ 7] ^ rk[14];
+            rk += 8;
+        }
     }
-  }
-  return 0;
+
+    return 0;
 }
 
 
-/**
- * Expand the cipher key into the decryption key schedule.
- *
- * @return the number of rounds for the given cipher key size.
- */
-#define INVMIXCOLRK(n) rk[n] = Td0[b0(Te4[b3(rk[n])])] ^ Td1[b0(Te4[b2(rk[n])])] ^ Td2[b0(Te4[b1(rk[n])])] ^ Td3[b0(Te4[b0(rk[n])])]
+#define INVMIXCOLRK(n)    rk[n] = Td0[b0(Te4[b3(rk[n])])] ^ Td1[b0(Te4[b2(rk[n])])] ^ Td2[b0(Te4[b1(rk[n])])] ^ Td3[b0(Te4[b0(rk[n])])]
 
+// expand the cipher key into the decryption key schedule and
+// return the number of rounds for the given cipher key size
 static int aes_internal_key_setup_dec (uint32_t rk[/*4*(Nr + 1)*/], const uint8_t cipherKey[], int keyBits) {
 
-  int Nr, i, j;
-  uint32_t temp;
-
-  // expand the cipher key
-  Nr = aes_internal_key_setup_enc(rk, cipherKey, keyBits);
-  // invert the order of the round keys
-  for (i = 0, j = 4*Nr; i < j; i += 4, j -= 4) {
-    temp = rk[i    ]; rk[i    ] = rk[j    ]; rk[j    ] = temp;
-    temp = rk[i + 1]; rk[i + 1] = rk[j + 1]; rk[j + 1] = temp;
-    temp = rk[i + 2]; rk[i + 2] = rk[j + 2]; rk[j + 2] = temp;
-    temp = rk[i + 3]; rk[i + 3] = rk[j + 3]; rk[j + 3] = temp;
-  }
-  // apply the inverse MixColumn transform to all round keys but the first and the last
-  for (i = 1; i < Nr; i++) {
-    rk += 4;
-    INVMIXCOLRK(0);
-    INVMIXCOLRK(1);
-    INVMIXCOLRK(2);
-    INVMIXCOLRK(3);
-  }
-
-  return Nr;
+    int Nr, i, j;
+    uint32_t temp;
+
+    // expand the cipher key
+    Nr = aes_internal_key_setup_enc(rk, cipherKey, keyBits);
+    // invert the order of the round keys
+    for(i = 0, j = 4*Nr; i < j; i += 4, j -= 4) {
+        temp = rk[i    ]; rk[i    ] = rk[j    ]; rk[j    ] = temp;
+        temp = rk[i + 1]; rk[i + 1] = rk[j + 1]; rk[j + 1] = temp;
+        temp = rk[i + 2]; rk[i + 2] = rk[j + 2]; rk[j + 2] = temp;
+        temp = rk[i + 3]; rk[i + 3] = rk[j + 3]; rk[j + 3] = temp;
+    }
+
+    // apply the inverse MixColumn transform to all round keys but the first and the last
+    for(i = 1; i < Nr; i++) {
+        rk += 4;
+        INVMIXCOLRK(0);
+        INVMIXCOLRK(1);
+        INVMIXCOLRK(2);
+        INVMIXCOLRK(3);
+    }
+
+    return Nr;
 }
 
 
 #define AES_ENC_ROUND(DST, SRC, round) \
-  DST##0 = Te0[b3(SRC##0)] ^ Te1[b2(SRC##1)] ^ Te2[b1(SRC##2)] ^ Te3[b0(SRC##3)] ^ rk[4 * round + 0]; \
-  DST##1 = Te0[b3(SRC##1)] ^ Te1[b2(SRC##2)] ^ Te2[b1(SRC##3)] ^ Te3[b0(SRC##0)] ^ rk[4 * round + 1]; \
-  DST##2 = Te0[b3(SRC##2)] ^ Te1[b2(SRC##3)] ^ Te2[b1(SRC##0)] ^ Te3[b0(SRC##1)] ^ rk[4 * round + 2]; \
-  DST##3 = Te0[b3(SRC##3)] ^ Te1[b2(SRC##0)] ^ Te2[b1(SRC##1)] ^ Te3[b0(SRC##2)] ^ rk[4 * round + 3];
+    DST##0 = Te0[b3(SRC##0)] ^ Te1[b2(SRC##1)] ^ Te2[b1(SRC##2)] ^ Te3[b0(SRC##3)] ^ rk[4 * round + 0]; \
+    DST##1 = Te0[b3(SRC##1)] ^ Te1[b2(SRC##2)] ^ Te2[b1(SRC##3)] ^ Te3[b0(SRC##0)] ^ rk[4 * round + 1]; \
+    DST##2 = Te0[b3(SRC##2)] ^ Te1[b2(SRC##3)] ^ Te2[b1(SRC##0)] ^ Te3[b0(SRC##1)] ^ rk[4 * round + 2]; \
+    DST##3 = Te0[b3(SRC##3)] ^ Te1[b2(SRC##0)] ^ Te2[b1(SRC##1)] ^ Te3[b0(SRC##2)] ^ rk[4 * round + 3];
 
 
 static void aes_internal_encrypt (const uint32_t rk[/*4*(Nr + 1)*/], int Nr, const uint8_t pt[16], uint8_t ct[16]) {
 
-  uint32_t s0, s1, s2, s3, t0, t1, t2, t3;
-
-  // map byte array block to cipher state and add initial round key
-  s0 = GETU32(pt     ) ^ rk[0];
-  s1 = GETU32(pt +  4) ^ rk[1];
-  s2 = GETU32(pt +  8) ^ rk[2];
-  s3 = GETU32(pt + 12) ^ rk[3];
-
-  AES_ENC_ROUND(t, s, 1);
-  AES_ENC_ROUND(s, t, 2);
-  AES_ENC_ROUND(t, s, 3);
-  AES_ENC_ROUND(s, t, 4);
-  AES_ENC_ROUND(t, s, 5);
-  AES_ENC_ROUND(s, t, 6);
-  AES_ENC_ROUND(t, s, 7);
-  AES_ENC_ROUND(s, t, 8);
-  AES_ENC_ROUND(t, s, 9);
-
-  if(Nr > 10) {
-    AES_ENC_ROUND(s, t, 10);
-    AES_ENC_ROUND(t, s, 11);
-    if(Nr > 12) {
-      AES_ENC_ROUND(s, t, 12);
-      AES_ENC_ROUND(t, s, 13);
+    uint32_t s0, s1, s2, s3, t0, t1, t2, t3;
+
+    // map byte array block to cipher state and add initial round key
+    s0 = GETU32(pt     ) ^ rk[0];
+    s1 = GETU32(pt +  4) ^ rk[1];
+    s2 = GETU32(pt +  8) ^ rk[2];
+    s3 = GETU32(pt + 12) ^ rk[3];
+
+    AES_ENC_ROUND(t, s, 1);
+    AES_ENC_ROUND(s, t, 2);
+    AES_ENC_ROUND(t, s, 3);
+    AES_ENC_ROUND(s, t, 4);
+    AES_ENC_ROUND(t, s, 5);
+    AES_ENC_ROUND(s, t, 6);
+    AES_ENC_ROUND(t, s, 7);
+    AES_ENC_ROUND(s, t, 8);
+    AES_ENC_ROUND(t, s, 9);
+
+    if(Nr > 10) {
+        AES_ENC_ROUND(s, t, 10);
+        AES_ENC_ROUND(t, s, 11);
+        if(Nr > 12) {
+            AES_ENC_ROUND(s, t, 12);
+            AES_ENC_ROUND(t, s, 13);
+        }
     }
-  }
-
-  rk += Nr << 2;
-  // apply last round and map cipher state to byte array block
-  s0 = m3(Te4[b3(t0)]) ^ m2(Te4[b2(t1)]) ^ m1(Te4[b1(t2)]) ^ m0(Te4[b0(t3)]) ^ rk[0];
-  PUTU32(ct     , s0);
-  s1 = m3(Te4[b3(t1)]) ^ m2(Te4[b2(t2)]) ^ m1(Te4[b1(t3)]) ^ m0(Te4[b0(t0)]) ^ rk[1];
-  PUTU32(ct +  4, s1);
-  s2 = m3(Te4[b3(t2)]) ^ m2(Te4[b2(t3)]) ^ m1(Te4[b1(t0)]) ^ m0(Te4[b0(t1)]) ^ rk[2];
-  PUTU32(ct +  8, s2);
-  s3 = m3(Te4[b3(t3)]) ^ m2(Te4[b2(t0)]) ^ m1(Te4[b1(t1)]) ^ m0(Te4[b0(t2)]) ^ rk[3];
-  PUTU32(ct + 12, s3);
+
+    rk += Nr << 2;
+    // apply last round and map cipher state to byte array block
+    s0 = m3(Te4[b3(t0)]) ^ m2(Te4[b2(t1)]) ^ m1(Te4[b1(t2)]) ^ m0(Te4[b0(t3)]) ^ rk[0];
+    PUTU32(ct     , s0);
+    s1 = m3(Te4[b3(t1)]) ^ m2(Te4[b2(t2)]) ^ m1(Te4[b1(t3)]) ^ m0(Te4[b0(t0)]) ^ rk[1];
+    PUTU32(ct +  4, s1);
+    s2 = m3(Te4[b3(t2)]) ^ m2(Te4[b2(t3)]) ^ m1(Te4[b1(t0)]) ^ m0(Te4[b0(t1)]) ^ rk[2];
+    PUTU32(ct +  8, s2);
+    s3 = m3(Te4[b3(t3)]) ^ m2(Te4[b2(t0)]) ^ m1(Te4[b1(t1)]) ^ m0(Te4[b0(t2)]) ^ rk[3];
+    PUTU32(ct + 12, s3);
 }
 
 
 #define AES_DEC_ROUND(DST, SRC, round) \
-  DST##0 = Td0[b3(SRC##0)] ^ Td1[b2(SRC##3)] ^ Td2[b1(SRC##2)] ^ Td3[b0(SRC##1)] ^ rk[4 * round + 0]; \
-  DST##1 = Td0[b3(SRC##1)] ^ Td1[b2(SRC##0)] ^ Td2[b1(SRC##3)] ^ Td3[b0(SRC##2)] ^ rk[4 * round + 1]; \
-  DST##2 = Td0[b3(SRC##2)] ^ Td1[b2(SRC##1)] ^ Td2[b1(SRC##0)] ^ Td3[b0(SRC##3)] ^ rk[4 * round + 2]; \
-  DST##3 = Td0[b3(SRC##3)] ^ Td1[b2(SRC##2)] ^ Td2[b1(SRC##1)] ^ Td3[b0(SRC##0)] ^ rk[4 * round + 3];
+    DST##0 = Td0[b3(SRC##0)] ^ Td1[b2(SRC##3)] ^ Td2[b1(SRC##2)] ^ Td3[b0(SRC##1)] ^ rk[4 * round + 0]; \
+    DST##1 = Td0[b3(SRC##1)] ^ Td1[b2(SRC##0)] ^ Td2[b1(SRC##3)] ^ Td3[b0(SRC##2)] ^ rk[4 * round + 1]; \
+    DST##2 = Td0[b3(SRC##2)] ^ Td1[b2(SRC##1)] ^ Td2[b1(SRC##0)] ^ Td3[b0(SRC##3)] ^ rk[4 * round + 2]; \
+    DST##3 = Td0[b3(SRC##3)] ^ Td1[b2(SRC##2)] ^ Td2[b1(SRC##1)] ^ Td3[b0(SRC##0)] ^ rk[4 * round + 3];
 
 
 static void aes_internal_decrypt (const uint32_t rk[/*4*(Nr + 1)*/], int Nr, const uint8_t ct[16], uint8_t pt[16]) {
 
-  uint32_t s0, s1, s2, s3, t0, t1, t2, t3;
-
-  // map byte array block to cipher state and add initial round key
-  s0 = GETU32(ct     ) ^ rk[0];
-  s1 = GETU32(ct +  4) ^ rk[1];
-  s2 = GETU32(ct +  8) ^ rk[2];
-  s3 = GETU32(ct + 12) ^ rk[3];
-
-  AES_DEC_ROUND(t, s, 1);
-  AES_DEC_ROUND(s, t, 2);
-  AES_DEC_ROUND(t, s, 3);
-  AES_DEC_ROUND(s, t, 4);
-  AES_DEC_ROUND(t, s, 5);
-  AES_DEC_ROUND(s, t, 6);
-  AES_DEC_ROUND(t, s, 7);
-  AES_DEC_ROUND(s, t, 8);
-  AES_DEC_ROUND(t, s, 9);
-
-  if(Nr > 10) {
-    AES_DEC_ROUND(s, t, 10);
-    AES_DEC_ROUND(t, s, 11);
-    if(Nr > 12) {
-      AES_DEC_ROUND(s, t, 12);
-      AES_DEC_ROUND(t, s, 13);
+    uint32_t s0, s1, s2, s3, t0, t1, t2, t3;
+
+    // map byte array block to cipher state and add initial round key
+    s0 = GETU32(ct     ) ^ rk[0];
+    s1 = GETU32(ct +  4) ^ rk[1];
+    s2 = GETU32(ct +  8) ^ rk[2];
+    s3 = GETU32(ct + 12) ^ rk[3];
+
+    AES_DEC_ROUND(t, s, 1);
+    AES_DEC_ROUND(s, t, 2);
+    AES_DEC_ROUND(t, s, 3);
+    AES_DEC_ROUND(s, t, 4);
+    AES_DEC_ROUND(t, s, 5);
+    AES_DEC_ROUND(s, t, 6);
+    AES_DEC_ROUND(t, s, 7);
+    AES_DEC_ROUND(s, t, 8);
+    AES_DEC_ROUND(t, s, 9);
+
+    if(Nr > 10) {
+        AES_DEC_ROUND(s, t, 10);
+        AES_DEC_ROUND(t, s, 11);
+        if(Nr > 12) {
+            AES_DEC_ROUND(s, t, 12);
+            AES_DEC_ROUND(t, s, 13);
+        }
     }
-  }
-
-  rk += Nr << 2;
-  // apply last round and map cipher state to byte array block
-  s0 = m3(Td4[b3(t0)]) ^ m2(Td4[b2(t3)]) ^ m1(Td4[b1(t2)]) ^ m0(Td4[b0(t1)]) ^ rk[0];
-  PUTU32(pt     , s0);
-  s1 = m3(Td4[b3(t1)]) ^ m2(Td4[b2(t0)]) ^ m1(Td4[b1(t3)]) ^ m0(Td4[b0(t2)]) ^ rk[1];
-  PUTU32(pt +  4, s1);
-  s2 = m3(Td4[b3(t2)]) ^ m2(Td4[b2(t1)]) ^ m1(Td4[b1(t0)]) ^ m0(Td4[b0(t3)]) ^ rk[2];
-  PUTU32(pt +  8, s2);
-  s3 = m3(Td4[b3(t3)]) ^ m2(Td4[b2(t2)]) ^ m1(Td4[b1(t1)]) ^ m0(Td4[b0(t0)]) ^ rk[3];
-  PUTU32(pt + 12, s3);
+
+    rk += Nr << 2;
+    // apply last round and map cipher state to byte array block
+    s0 = m3(Td4[b3(t0)]) ^ m2(Td4[b2(t3)]) ^ m1(Td4[b1(t2)]) ^ m0(Td4[b0(t1)]) ^ rk[0];
+    PUTU32(pt     , s0);
+    s1 = m3(Td4[b3(t1)]) ^ m2(Td4[b2(t0)]) ^ m1(Td4[b1(t3)]) ^ m0(Td4[b0(t2)]) ^ rk[1];
+    PUTU32(pt +  4, s1);
+    s2 = m3(Td4[b3(t2)]) ^ m2(Td4[b2(t1)]) ^ m1(Td4[b1(t0)]) ^ m0(Td4[b0(t3)]) ^ rk[2];
+    PUTU32(pt +  8, s2);
+    s3 = m3(Td4[b3(t3)]) ^ m2(Td4[b2(t2)]) ^ m1(Td4[b1(t1)]) ^ m0(Td4[b0(t0)]) ^ rk[3];
+    PUTU32(pt + 12, s3);
 }
 
 
@@ -1194,224 +1212,103 @@ static void aes_internal_decrypt (const uint32_t rk[/*4*(Nr + 1)*/], int Nr, con
 
 int aes_ecb_decrypt (unsigned char *out, const unsigned char *in, aes_context_t *ctx) {
 
-  aes_internal_decrypt(ctx->dec_rk, ctx->Nr, in, out);
+    aes_internal_decrypt(ctx->dec_rk, ctx->Nr, in, out);
 
-  return AES_BLOCK_SIZE;
+    return AES_BLOCK_SIZE;
 }
 
 
 // not used
 int aes_ecb_encrypt (unsigned char *out, const unsigned char *in, aes_context_t *ctx) {
 
-  aes_internal_encrypt(ctx->enc_rk, ctx->Nr, in, out);
+    aes_internal_encrypt(ctx->enc_rk, ctx->Nr, in, out);
 
-  return AES_BLOCK_SIZE;
+    return AES_BLOCK_SIZE;
 }
 
 
 #define fix_xor(target, source) *(uint32_t*)&(target)[0] = *(uint32_t*)&(target)[0] ^ *(uint32_t*)&(source)[0]; *(uint32_t*)&(target)[4] = *(uint32_t*)&(target)[4] ^ *(uint32_t*)&(source)[4]; \
                                 *(uint32_t*)&(target)[8] = *(uint32_t*)&(target)[8] ^ *(uint32_t*)&(source)[8]; *(uint32_t*)&(target)[12] = *(uint32_t*)&(target)[12] ^ *(uint32_t*)&(source)[12];
+
 
 int aes_cbc_encrypt (unsigned char *out, const unsigned char *in, size_t in_len,
                      const unsigned char *iv, aes_context_t *ctx) {
 
-  uint8_t tmp[AES_BLOCK_SIZE];
-  size_t i;
-  size_t n;
+    uint8_t tmp[AES_BLOCK_SIZE];
+    size_t i;
+    size_t n;
+
+    memcpy(tmp, iv, AES_BLOCK_SIZE);
 
-  memcpy(tmp, iv, AES_BLOCK_SIZE);
+    n = in_len / AES_BLOCK_SIZE;
+    for(i=0; i < n; i++) {
+        fix_xor(tmp, &in[i * AES_BLOCK_SIZE]);
+        aes_internal_encrypt(ctx->enc_rk, ctx->Nr, tmp, tmp);
+        memcpy(&out[i * AES_BLOCK_SIZE], tmp, AES_BLOCK_SIZE);
+    }
 
-  n = in_len / AES_BLOCK_SIZE;
-  for(i=0; i < n; i++) {
-    fix_xor(tmp, &in[i * AES_BLOCK_SIZE]);
-    aes_internal_encrypt(ctx->enc_rk, ctx->Nr, tmp, tmp);
-    memcpy(&out[i * AES_BLOCK_SIZE], tmp, AES_BLOCK_SIZE);
-  }
-  return n * AES_BLOCK_SIZE;
+    return n * AES_BLOCK_SIZE;
 }
 
 
 int aes_cbc_decrypt (unsigned char *out, const unsigned char *in, size_t in_len,
                      const unsigned char *iv, aes_context_t *ctx) {
 
-  uint8_t tmp[AES_BLOCK_SIZE];
-  uint8_t old[AES_BLOCK_SIZE];
-  size_t i;
-  size_t n;
+    uint8_t tmp[AES_BLOCK_SIZE];
+    uint8_t old[AES_BLOCK_SIZE];
+    size_t i;
+    size_t n;
 
-  memcpy(tmp, iv, AES_BLOCK_SIZE);
-
-  n = in_len / AES_BLOCK_SIZE;
-  for(i=0; i < n; i++) {
-    memcpy(old, &in[i * AES_BLOCK_SIZE], AES_BLOCK_SIZE);
-    aes_internal_decrypt(ctx->dec_rk, ctx->Nr, &in[i * AES_BLOCK_SIZE], &out[i * AES_BLOCK_SIZE]);
-    fix_xor(&out[i * AES_BLOCK_SIZE], tmp);
-    memcpy(tmp, old, AES_BLOCK_SIZE);
-  }
-
-  return n * AES_BLOCK_SIZE;
-}
+    memcpy(tmp, iv, AES_BLOCK_SIZE);
 
-int aes_init (const unsigned char *key, size_t key_size, aes_context_t **ctx) {
+    n = in_len / AES_BLOCK_SIZE;
+    for(i=0; i < n; i++) {
+        memcpy(old, &in[i * AES_BLOCK_SIZE], AES_BLOCK_SIZE);
+        aes_internal_decrypt(ctx->dec_rk, ctx->Nr, &in[i * AES_BLOCK_SIZE], &out[i * AES_BLOCK_SIZE]);
+        fix_xor(&out[i * AES_BLOCK_SIZE], tmp);
+        memcpy(tmp, old, AES_BLOCK_SIZE);
+    }
 
-  // allocate context...
-  *ctx = (aes_context_t*) calloc(1, sizeof(aes_context_t));
-  if (!(*ctx))
-    return -1;
-  // ...and fill her up
-
-  // initialize data structures
-
-  // check key size and make key size (given in bytes) dependant settings
-  switch(key_size) {
-    case AES128_KEY_BYTES: // 128 bit key size
-      break;
-    case AES192_KEY_BYTES: // 192 bit key size
-      break;
-    case AES256_KEY_BYTES: // 256 bit key size
-      break;
-    default:
-      traceEvent(TRACE_ERROR, "aes_init invalid key size %u\n", key_size);
-      return -1;
-  }
-
-  // key materiel handling
-  (*ctx)->Nr = aes_internal_key_setup_enc((*ctx)->enc_rk/*[4*(Nr + 1)]*/, key, 8 * key_size);
-  aes_internal_key_setup_dec((*ctx)->dec_rk/*[4*(Nr + 1)]*/, key, 8 * key_size);
-  return 0;
+    return n * AES_BLOCK_SIZE;
 }
 
-#endif // openSSL 1.1, AES-NI, plain C ----------------------------------------------------
-
-int aes_deinit (aes_context_t *ctx) {
-
-  if (ctx) free (ctx);
-
-  return 0;
-}
 
+int aes_init (const unsigned char *key, size_t key_size, aes_context_t **ctx) {
+
+    // allocate context...
+    *ctx = (aes_context_t*) calloc(1, sizeof(aes_context_t));
+    if(!(*ctx))
+        return -1;
+    // ...and fill her up:
+
+    // initialize data structures
+
+    // check key size and make key size (given in bytes) dependent settings
+    switch(key_size) {
+        case AES128_KEY_BYTES: // 128 bit key size
+            break;
+        case AES192_KEY_BYTES: // 192 bit key size
+            break;
+        case AES256_KEY_BYTES: // 256 bit key size
+            break;
+        default:
+            traceEvent(TRACE_ERROR, "aes_init invalid key size %u\n", key_size);
+            return -1;
+    }
 
-// --- for testing ------------------------------------------------------------------------
-// --- remove when done ---
-
-/* int aes_init (const unsigned char *key, size_t key_size, aes_context_t **ctx) {
-
-  // allocate context...
-  *ctx = (aes_context_t*) calloc(1, sizeof(aes_context_t));
-  if (!(*ctx))
-    return -1;
-  // ...and fill her up
-
-  // initialize data structures
-#ifdef HAVE_OPENSSL_1_1
-  if(!((*ctx)->enc_ctx = EVP_CIPHER_CTX_new())) {
-    traceEvent(TRACE_ERROR, "aes_init openssl's evp_* encryption context creation failed: %s",
-               openssl_err_as_string());
-    return(-1);
-  }
-  if(!((*ctx)->dec_ctx = EVP_CIPHER_CTX_new())) {
-    traceEvent(TRACE_ERROR, "aes_init openssl's evp_* decryption context creation failed: %s",
-               openssl_err_as_string());
-    return(-1);
-  }
-#endif
-
-  // check key size and make key size (given in bytes) dependant settings
-  switch(key_size) {
-    case AES128_KEY_BYTES: // 128 bit key size
-#ifdef HAVE_OPENSSL_1_1
-      (*ctx)->cipher = EVP_aes_128_cbc();
-#endif
-      break;
-    case AES192_KEY_BYTES: // 192 bit key size
-#ifdef HAVE_OPENSSL_1_1
-      (*ctx)->cipher = EVP_aes_192_cbc();
-#endif
-      break;
-    case AES256_KEY_BYTES: // 256 bit key size
-#ifdef HAVE_OPENSSL_1_1
-      (*ctx)->cipher = EVP_aes_256_cbc();
-#endif
-      break;
-    default:
-      traceEvent(TRACE_ERROR, "aes_init invalid key size %u\n", key_size);
-      return -1;
-  }
-
-  // key materiel handling
-#ifdef HAVE_OPENSSL_1_1
-  memcpy((*ctx)->key, key, key_size);
-  AES_set_decrypt_key(key, key_size * 8, &((*ctx)->ecb_dec_key));
-#else
-  AES_set_encrypt_key(key, key_size * 8, &((*ctx)->enc_key));
-  AES_set_decrypt_key(key, key_size * 8, &((*ctx)->dec_key));
-#endif
-
-  return 0;
+    // key material handling
+    (*ctx)->Nr = aes_internal_key_setup_enc((*ctx)->enc_rk/*[4*(Nr + 1)]*/, key, 8 * key_size);
+    aes_internal_key_setup_dec((*ctx)->dec_rk/*[4*(Nr + 1)]*/, key, 8 * key_size);
+    return 0;
 }
 
-#ifdef TEST_AES
-int main () {
-
-  aes_context_t *ctx;
-
-
-// *ctx = malloc(sizeof(aes_context_t));
-
+#endif // openSSL 1.1, AES-NI, plain C ----------------------------------------------------------------------------
 
-// uint8_t key[32] = {0};
-// 128 bit key 0 --> 0336763e966d92595a567cc9ce537f5e
-// uint8_t pt[16] = {0xf3, 0x44, 0x81, 0xec, 0x3c, 0xc6, 0x27, 0xba,
-//                   0xcd, 0x5d, 0xc3, 0xfb, 0x08, 0xf2, 0x73, 0xe6 };
-// 256 bit key 0 --> 5c9d844ed46f9885085e5d6a4f94c7d7
-// uint8_t pt[16] = {0x01, 0x47, 0x30, 0xf8, 0x0a, 0xc6, 0x25, 0xfe,
-//                   0x84, 0xf0, 0x26, 0xc6, 0x0b, 0xfd, 0x54, 0x7d };
-
-  uint8_t pt[16] = {0};
-// 0 pt --> 6d251e6944b051e04eaa6fb4dbf78465
-  uint8_t key[16] = {0x10, 0xa5, 0x88, 0x69, 0xd7, 0x4b, 0xe5, 0xa3,
-                     0x74, 0xcf, 0x86, 0x7c, 0xfb, 0x47, 0x38, 0x59 };
-
-  uint8_t ct[16] = {0};
-  int i;
-
-  // aes_internal_key_setup (ctx, key, 8 * sizeof(key));
-  aes_init (key, sizeof(key), &ctx);
-
-  printf ("Nr = %u\n",(ctx)->Nr);
-  memset (pt, 0, 16);
-
-  for(i = 0; i < 16; i++)
-    printf ("%02x",pt[i]);
-  printf ("--- pt\n");
-
-  aes_internal_encrypt((ctx), pt, ct);
-  memset (pt, 4, 16);
-
-  for(i = 0; i < 16; i++)
-    printf ("%02x",ct[i]);
-  printf ("--- ct\n");
-
-
-  printf ("Nr = %u\n",(ctx)->Nr);
-  printf ("Nr = %u\n",(ctx)->Nr);
-
-  aes_internal_decrypt((ctx), ct, pt);
-  memset (ct, 9, 16);
-  for(i = 0; i < 16; i++)
-    printf ("%02x",pt[i]);
-  printf ("--- pt\n");
 
+int aes_deinit (aes_context_t *ctx) {
 
-  aes_internal_encrypt((ctx), pt, ct);
+    if(ctx) free(ctx);
 
-  for(i = 0; i < 16; i++)
-    printf ("%02x",ct[i]);
-  printf ("--- ct\n");
+    return 0;
 }
-#endif
-*/
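
--- reviewer note (not part of the patch) --------------------------------------------------------------------------

Since the commit removes the old TEST_AES scaffolding, here is a minimal round-trip smoke test against the public
interface kept by this patch (aes_init, aes_cbc_encrypt, aes_cbc_decrypt, aes_deinit). It is a sketch under stated
assumptions, not project code: the key and IV values are arbitrary, 16 is assumed as the AES block size, the input
length is kept a whole multiple of the block size because the aes_cbc_* functions do no padding, and it must be
built and linked against the project (aes.h plus the n2n tracing facilities that src/aes.c uses). The same source
exercises whichever of the three backends the build selects.

    #include <stdio.h>
    #include <string.h>

    #include "aes.h"    // aes_context_t and the aes_* functions touched by this diff

    int main (void) {

        aes_context_t *ctx = NULL;
        unsigned char key[AES128_KEY_BYTES];      // 16 bytes --> takes the AES-128 path in aes_init()
        unsigned char iv[16];                     // one AES block
        unsigned char pt[32], ct[32], out[32];    // two whole blocks, so no padding is needed

        memset(key, 0x55, sizeof(key));           // arbitrary demo key, not a published test vector
        memset(iv,  0xa0, sizeof(iv));            // arbitrary demo iv
        memset(pt,  0x00, sizeof(pt));

        if(aes_init(key, sizeof(key), &ctx) != 0) {
            fprintf(stderr, "aes_init failed\n");
            return 1;
        }

        // encrypt, then decrypt with the same iv; a correct backend must reproduce pt
        aes_cbc_encrypt(ct, pt, sizeof(pt), iv, ctx);
        aes_cbc_decrypt(out, ct, sizeof(ct), iv, ctx);

        printf("cbc round trip: %s\n", memcmp(pt, out, sizeof(pt)) ? "FAILED" : "ok");

        aes_deinit(ctx);

        return 0;
    }

Running the binary once per backend configuration (HAVE_OPENSSL_1_1, __AES__/__SSE2__, plain C) gives a quick
consistency check that all three code paths still agree after the reformatting.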