@@ -365,48 +365,194 @@ int aes_ecb_encrypt (unsigned char *out, const unsigned char *in, aes_context_t
 }
 
-#define fix_xor(target, source) *(uint32_t*)&(target)[0] = *(uint32_t*)&(target)[0] ^ *(uint32_t*)&(source)[0]; *(uint32_t*)&(target)[4] = *(uint32_t*)&(target)[4] ^ *(uint32_t*)&(source)[4]; \
-                                *(uint32_t*)&(target)[8] = *(uint32_t*)&(target)[8] ^ *(uint32_t*)&(source)[8]; *(uint32_t*)&(target)[12] = *(uint32_t*)&(target)[12] ^ *(uint32_t*)&(source)[12];
 
 int aes_cbc_encrypt (unsigned char *out, const unsigned char *in, size_t in_len,
                      const unsigned char *iv, aes_context_t *ctx) {
 
-    uint8_t tmp[AES_BLOCK_SIZE];
-    size_t i;
-    size_t n;
-    memcpy(tmp, iv, AES_BLOCK_SIZE);
-    n = in_len / AES_BLOCK_SIZE;
-    for(i = 0; i < n; i++) {
-        fix_xor(tmp, &in[i * AES_BLOCK_SIZE]);
-        aes_internal_encrypt(ctx, tmp, tmp);
-        memcpy(&out[i * AES_BLOCK_SIZE], tmp, AES_BLOCK_SIZE);
+    int n; // number of blocks
+    int ret = (int)in_len & 15; // remainder
+    __m128i ivec = _mm_loadu_si128((__m128i*)iv);
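+    // note: cbc encryption is inherently serial, each block needs the previous block's ciphertext first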
+    for(n = in_len / 16; n != 0; n--) {
+        __m128i tmp = _mm_loadu_si128((__m128i*)in);
+        in += 16;
+        // xor with the previous ciphertext block (the iv at first) and apply the whitening key
+        tmp = _mm_xor_si128(tmp, ivec);
+        tmp = _mm_xor_si128(tmp, ctx->rk_enc[0]);
+        tmp = _mm_aesenc_si128(tmp, ctx->rk_enc[1]);
+        tmp = _mm_aesenc_si128(tmp, ctx->rk_enc[2]);
+        tmp = _mm_aesenc_si128(tmp, ctx->rk_enc[3]);
+        tmp = _mm_aesenc_si128(tmp, ctx->rk_enc[4]);
+        tmp = _mm_aesenc_si128(tmp, ctx->rk_enc[5]);
+        tmp = _mm_aesenc_si128(tmp, ctx->rk_enc[6]);
+        tmp = _mm_aesenc_si128(tmp, ctx->rk_enc[7]);
+        tmp = _mm_aesenc_si128(tmp, ctx->rk_enc[8]);
+        tmp = _mm_aesenc_si128(tmp, ctx->rk_enc[9]);
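+        // nine regular rounds are common to all key sizes; the longer keys (AES-192: Nr = 12, AES-256: Nr = 14) take extra rounds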
+        if(ctx->Nr > 10) {
+            tmp = _mm_aesenc_si128(tmp, ctx->rk_enc[10]);
+            tmp = _mm_aesenc_si128(tmp, ctx->rk_enc[11]);
+            if(ctx->Nr > 12) {
+                tmp = _mm_aesenc_si128(tmp, ctx->rk_enc[12]);
+                tmp = _mm_aesenc_si128(tmp, ctx->rk_enc[13]);
+            }
+        }
+        tmp = _mm_aesenclast_si128(tmp, ctx->rk_enc[ctx->Nr]);
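+        // the finished ciphertext block doubles as chaining value for the next block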
+        ivec = tmp;
+        _mm_storeu_si128((__m128i*)out, tmp);
+        out += 16;
     }
-    return n * AES_BLOCK_SIZE;
+    return ret;
 }
 
 int aes_cbc_decrypt (unsigned char *out, const unsigned char *in, size_t in_len,
                      const unsigned char *iv, aes_context_t *ctx) {
 
-    uint8_t tmp[AES_BLOCK_SIZE];
-    uint8_t old[AES_BLOCK_SIZE];
-    size_t i;
-    size_t n;
-    memcpy(tmp, iv, AES_BLOCK_SIZE);
-    n = in_len / AES_BLOCK_SIZE;
-    for(i = 0; i < n; i++) {
-        memcpy(old, &in[i * AES_BLOCK_SIZE], AES_BLOCK_SIZE);
-        aes_internal_decrypt(ctx, &in[i * AES_BLOCK_SIZE], &out[i * AES_BLOCK_SIZE]);
-        fix_xor(&out[i * AES_BLOCK_SIZE], tmp);
-        memcpy(tmp, old, AES_BLOCK_SIZE);
+    int n; // number of blocks
+    int ret = (int)in_len & 15; // remainder
+    __m128i ivec = _mm_loadu_si128((__m128i*)iv);
+    for(n = in_len / 16; n > 3; n -= 4) {
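+        // unlike encryption, cbc decryption parallelizes well, so four independent blocks are handled at a time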
+        __m128i tmp1 = _mm_loadu_si128((__m128i*)in); in += 16;
+        __m128i tmp2 = _mm_loadu_si128((__m128i*)in); in += 16;
+        __m128i tmp3 = _mm_loadu_si128((__m128i*)in); in += 16;
+        __m128i tmp4 = _mm_loadu_si128((__m128i*)in); in += 16;
+        __m128i old_in1 = tmp1;
+        __m128i old_in2 = tmp2;
+        __m128i old_in3 = tmp3;
+        __m128i old_in4 = tmp4;
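+        // keep the original ciphertext blocks, they are the chaining values for the xor step below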
+        tmp1 = _mm_xor_si128(tmp1, ctx->rk_dec[0]);    tmp2 = _mm_xor_si128(tmp2, ctx->rk_dec[0]);
+        tmp3 = _mm_xor_si128(tmp3, ctx->rk_dec[0]);    tmp4 = _mm_xor_si128(tmp4, ctx->rk_dec[0]);
+        tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[1]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[1]);
+        tmp3 = _mm_aesdec_si128(tmp3, ctx->rk_dec[1]); tmp4 = _mm_aesdec_si128(tmp4, ctx->rk_dec[1]);
+        tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[2]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[2]);
+        tmp3 = _mm_aesdec_si128(tmp3, ctx->rk_dec[2]); tmp4 = _mm_aesdec_si128(tmp4, ctx->rk_dec[2]);
+        tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[3]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[3]);
+        tmp3 = _mm_aesdec_si128(tmp3, ctx->rk_dec[3]); tmp4 = _mm_aesdec_si128(tmp4, ctx->rk_dec[3]);
+        tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[4]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[4]);
+        tmp3 = _mm_aesdec_si128(tmp3, ctx->rk_dec[4]); tmp4 = _mm_aesdec_si128(tmp4, ctx->rk_dec[4]);
+        tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[5]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[5]);
+        tmp3 = _mm_aesdec_si128(tmp3, ctx->rk_dec[5]); tmp4 = _mm_aesdec_si128(tmp4, ctx->rk_dec[5]);
+        tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[6]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[6]);
+        tmp3 = _mm_aesdec_si128(tmp3, ctx->rk_dec[6]); tmp4 = _mm_aesdec_si128(tmp4, ctx->rk_dec[6]);
+        tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[7]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[7]);
+        tmp3 = _mm_aesdec_si128(tmp3, ctx->rk_dec[7]); tmp4 = _mm_aesdec_si128(tmp4, ctx->rk_dec[7]);
+        tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[8]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[8]);
+        tmp3 = _mm_aesdec_si128(tmp3, ctx->rk_dec[8]); tmp4 = _mm_aesdec_si128(tmp4, ctx->rk_dec[8]);
+        tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[9]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[9]);
+        tmp3 = _mm_aesdec_si128(tmp3, ctx->rk_dec[9]); tmp4 = _mm_aesdec_si128(tmp4, ctx->rk_dec[9]);
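+        // key-size dependent extra rounds, same scheme as in aes_cbc_encrypt above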
+        if(ctx->Nr > 10) {
+            tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[10]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[10]);
+            tmp3 = _mm_aesdec_si128(tmp3, ctx->rk_dec[10]); tmp4 = _mm_aesdec_si128(tmp4, ctx->rk_dec[10]);
+            tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[11]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[11]);
+            tmp3 = _mm_aesdec_si128(tmp3, ctx->rk_dec[11]); tmp4 = _mm_aesdec_si128(tmp4, ctx->rk_dec[11]);
+            if(ctx->Nr > 12) {
+                tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[12]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[12]);
+                tmp3 = _mm_aesdec_si128(tmp3, ctx->rk_dec[12]); tmp4 = _mm_aesdec_si128(tmp4, ctx->rk_dec[12]);
+                tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[13]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[13]);
+                tmp3 = _mm_aesdec_si128(tmp3, ctx->rk_dec[13]); tmp4 = _mm_aesdec_si128(tmp4, ctx->rk_dec[13]);
+            }
+        }
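+        // the last round uses the first encryption round key (equivalent inverse cipher key schedule);
+        // xor-ing with the previous ciphertext block (the iv at first) then undoes the cbc chaining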
+        tmp1 = _mm_aesdeclast_si128(tmp1, ctx->rk_enc[0]); tmp2 = _mm_aesdeclast_si128(tmp2, ctx->rk_enc[0]);
+        tmp3 = _mm_aesdeclast_si128(tmp3, ctx->rk_enc[0]); tmp4 = _mm_aesdeclast_si128(tmp4, ctx->rk_enc[0]);
+        tmp1 = _mm_xor_si128(tmp1, ivec);    tmp2 = _mm_xor_si128(tmp2, old_in1);
+        tmp3 = _mm_xor_si128(tmp3, old_in2); tmp4 = _mm_xor_si128(tmp4, old_in3);
+        ivec = old_in4;
+        _mm_storeu_si128((__m128i*)out, tmp1); out += 16;
+        _mm_storeu_si128((__m128i*)out, tmp2); out += 16;
+        _mm_storeu_si128((__m128i*)out, tmp3); out += 16;
+        _mm_storeu_si128((__m128i*)out, tmp4); out += 16;
+    } // now: less than 4 blocks remaining
+    if(n > 1) { // 2 or 3 blocks remaining --> this code handles two of them
+        n -= 2;
+        __m128i tmp1 = _mm_loadu_si128((__m128i*)in); in += 16;
+        __m128i tmp2 = _mm_loadu_si128((__m128i*)in); in += 16;
+        __m128i old_in1 = tmp1;
+        __m128i old_in2 = tmp2;
+        tmp1 = _mm_xor_si128(tmp1, ctx->rk_dec[0]);    tmp2 = _mm_xor_si128(tmp2, ctx->rk_dec[0]);
+        tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[1]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[1]);
+        tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[2]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[2]);
+        tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[3]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[3]);
+        tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[4]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[4]);
+        tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[5]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[5]);
+        tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[6]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[6]);
+        tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[7]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[7]);
+        tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[8]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[8]);
+        tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[9]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[9]);
+        if(ctx->Nr > 10) {
+            tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[10]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[10]);
+            tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[11]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[11]);
+            if(ctx->Nr > 12) {
+                tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[12]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[12]);
+                tmp1 = _mm_aesdec_si128(tmp1, ctx->rk_dec[13]); tmp2 = _mm_aesdec_si128(tmp2, ctx->rk_dec[13]);
+            }
+        }
+        tmp1 = _mm_aesdeclast_si128(tmp1, ctx->rk_enc[0]); tmp2 = _mm_aesdeclast_si128(tmp2, ctx->rk_enc[0]);
+        tmp1 = _mm_xor_si128(tmp1, ivec); tmp2 = _mm_xor_si128(tmp2, old_in1);
+        ivec = old_in2;
+        _mm_storeu_si128((__m128i*)out, tmp1); out += 16;
+        _mm_storeu_si128((__m128i*)out, tmp2); out += 16;
     }
-    return n * AES_BLOCK_SIZE;
+    if(n) { // one block remaining
+        __m128i tmp = _mm_loadu_si128((__m128i*)in);
+        __m128i old_in = tmp;
+        tmp = _mm_xor_si128(tmp, ctx->rk_dec[0]);
+        tmp = _mm_aesdec_si128(tmp, ctx->rk_dec[1]);
+        tmp = _mm_aesdec_si128(tmp, ctx->rk_dec[2]);
+        tmp = _mm_aesdec_si128(tmp, ctx->rk_dec[3]);
+        tmp = _mm_aesdec_si128(tmp, ctx->rk_dec[4]);
+        tmp = _mm_aesdec_si128(tmp, ctx->rk_dec[5]);
+        tmp = _mm_aesdec_si128(tmp, ctx->rk_dec[6]);
+        tmp = _mm_aesdec_si128(tmp, ctx->rk_dec[7]);
+        tmp = _mm_aesdec_si128(tmp, ctx->rk_dec[8]);
+        tmp = _mm_aesdec_si128(tmp, ctx->rk_dec[9]);
+        if(ctx->Nr > 10) {
+            tmp = _mm_aesdec_si128(tmp, ctx->rk_dec[10]);
+            tmp = _mm_aesdec_si128(tmp, ctx->rk_dec[11]);
+            if(ctx->Nr > 12) {
+                tmp = _mm_aesdec_si128(tmp, ctx->rk_dec[12]);
+                tmp = _mm_aesdec_si128(tmp, ctx->rk_dec[13]);
+            }
+        }
+        tmp = _mm_aesdeclast_si128(tmp, ctx->rk_enc[0]);
+        tmp = _mm_xor_si128(tmp, ivec);
+        _mm_storeu_si128((__m128i*)out, tmp);
+    }
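+    // the trailing in_len & 15 bytes were left unprocessed, their number is returned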
+    return ret;
 }