OCB: Make primitive AES API explicit

Explicitly define the primitive AES API used by the internal OCB
implementation, and move it into its own namespace (ocb_aes). This will
ease future implementation changes.

Also make some style fixes to affected lines: Replace C-style casts
with C++-style casts, add some missing spaces in argument lists, and
remove some `inline` that the compiler will ignore.

Bug: https://github.com/mobile-shell/mosh/issues/1174
This commit is contained in:
Benjamin Barenblat
2022-06-27 17:52:11 -04:00
committed by Alex Chernyakhovsky
parent 0a30c5acd5
commit ad85b90505
+66 -36
View File
@@ -351,7 +351,7 @@
#endif
/* ----------------------------------------------------------------------- */
/* AES - Code uses OpenSSL API. Other implementations get mapped to it. */
/* AES */
/* ----------------------------------------------------------------------- */
/*---------------*/
@@ -360,21 +360,43 @@
#include <openssl/aes.h> /* http://openssl.org/ */
namespace ocb_aes {
typedef AES_KEY KEY;
static void set_encrypt_key(const unsigned char *user_key, int bits, KEY *key) {
AES_set_encrypt_key(user_key, bits, key);
}
static void set_decrypt_key(const unsigned char *user_key, int bits, KEY *key) {
AES_set_decrypt_key(user_key, bits, key);
}
static void encrypt(const unsigned char *in, unsigned char *out, KEY *key) {
return AES_encrypt(in, out, key);
}
static void decrypt(const unsigned char *in, unsigned char *out, KEY *key) {
return AES_decrypt(in, out, key);
}
/* How to ECB encrypt an array of blocks, in place */
static inline void AES_ecb_encrypt_blks(block *blks, unsigned nblks, AES_KEY *key) {
static void ecb_encrypt_blks(block *blks, unsigned nblks, KEY *key) {
while (nblks) {
--nblks;
AES_encrypt((unsigned char *)(blks+nblks), (unsigned char *)(blks+nblks), key);
encrypt(reinterpret_cast<unsigned char *>(blks+nblks), reinterpret_cast<unsigned char *>(blks+nblks), key);
}
}
static inline void AES_ecb_decrypt_blks(block *blks, unsigned nblks, AES_KEY *key) {
static void ecb_decrypt_blks(block *blks, unsigned nblks, KEY *key) {
while (nblks) {
--nblks;
AES_decrypt((unsigned char *)(blks+nblks), (unsigned char *)(blks+nblks), key);
decrypt(reinterpret_cast<unsigned char *>(blks+nblks), reinterpret_cast<unsigned char *>(blks+nblks), key);
}
}
} // namespace ocb_aes
#define BPI 4 /* Number of blocks in buffer per ECB call */
/*-------------------*/
@@ -384,12 +406,14 @@ static inline void AES_ecb_decrypt_blks(block *blks, unsigned nblks, AES_KEY *ke
#include <fatal_assert.h>
#include <CommonCrypto/CommonCryptor.h>
namespace ocb_aes {
typedef struct {
CCCryptorRef ref;
uint8_t b[4096];
} AES_KEY;
} KEY;
static inline void AES_set_encrypt_key(unsigned char *handle, const int bits, AES_KEY *key)
static void set_encrypt_key(const unsigned char *handle, const int bits, KEY *key)
{
CCCryptorStatus rv = CCCryptorCreateFromData(
kCCEncrypt,
@@ -405,7 +429,7 @@ static inline void AES_set_encrypt_key(unsigned char *handle, const int bits, AE
fatal_assert(rv == kCCSuccess);
}
static inline void AES_set_decrypt_key(unsigned char *handle, const int bits, AES_KEY *key)
static void set_decrypt_key(const unsigned char *handle, const int bits, KEY *key)
{
CCCryptorStatus rv = CCCryptorCreateFromData(
kCCDecrypt,
@@ -421,7 +445,7 @@ static inline void AES_set_decrypt_key(unsigned char *handle, const int bits, AE
fatal_assert(rv == kCCSuccess);
}
static inline void AES_encrypt(unsigned char *src, unsigned char *dst, AES_KEY *key) {
static void encrypt(unsigned char *src, unsigned char *dst, KEY *key) {
size_t dataOutMoved;
CCCryptorStatus rv = CCCryptorUpdate(
key->ref,
@@ -435,11 +459,11 @@ static inline void AES_encrypt(unsigned char *src, unsigned char *dst, AES_KEY *
}
#if 0
/* unused */
static inline void AES_decrypt(unsigned char *src, unsigned char *dst, AES_KEY *key) {
AES_encrypt(src, dst, key);
static void decrypt(unsigned char *src, unsigned char *dst, KEY *key) {
encrypt(src, dst, key);
}
#endif
static inline void AES_ecb_encrypt_blks(block *blks, unsigned nblks, AES_KEY *key) {
static void ecb_encrypt_blks(block *blks, unsigned nblks, KEY *key) {
const size_t dataSize = kCCBlockSizeAES128 * nblks;
size_t dataOutMoved;
CCCryptorStatus rv = CCCryptorUpdate(
@@ -452,10 +476,12 @@ static inline void AES_ecb_encrypt_blks(block *blks, unsigned nblks, AES_KEY *ke
fatal_assert(rv == kCCSuccess);
fatal_assert(dataOutMoved == dataSize);
}
static inline void AES_ecb_decrypt_blks(block *blks, unsigned nblks, AES_KEY *key) {
AES_ecb_encrypt_blks(blks, nblks, key);
static void ecb_decrypt_blks(block *blks, unsigned nblks, KEY *key) {
ecb_encrypt_blks(blks, nblks, key);
}
} // namespace ocb_aes
#define BPI 4 /* Number of blocks in buffer per ECB call */
/*-------------------*/
@@ -467,32 +493,36 @@ static inline void AES_ecb_decrypt_blks(block *blks, unsigned nblks, AES_KEY *ke
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#include <nettle/aes.h>
typedef struct aes_ctx AES_KEY;
namespace ocb_aes {
static inline void AES_set_encrypt_key(unsigned char *handle, const int bits, AES_KEY *key)
typedef struct aes_ctx KEY;
static void set_encrypt_key(const unsigned char *handle, const int bits, KEY *key)
{
nettle_aes_set_encrypt_key(key, bits/8, (const uint8_t *)handle);
}
static inline void AES_set_decrypt_key(unsigned char *handle, const int bits, AES_KEY *key)
static void set_decrypt_key(const unsigned char *handle, const int bits, KEY *key)
{
nettle_aes_set_decrypt_key(key, bits/8, (const uint8_t *)handle);
}
static inline void AES_encrypt(unsigned char *src, unsigned char *dst, AES_KEY *key) {
static void encrypt(unsigned char *src, unsigned char *dst, KEY *key) {
nettle_aes_encrypt(key, AES_BLOCK_SIZE, dst, src);
}
#if 0
/* unused */
static inline void AES_decrypt(unsigned char *src, unsigned char *dst, AES_KEY *key) {
static void decrypt(unsigned char *src, unsigned char *dst, KEY *key) {
nettle_aes_decrypt(key, AES_BLOCK_SIZE, dst, src);
}
#endif
static inline void AES_ecb_encrypt_blks(block *blks, unsigned nblks, AES_KEY *key) {
static void ecb_encrypt_blks(block *blks, unsigned nblks, KEY *key) {
nettle_aes_encrypt(key, nblks * AES_BLOCK_SIZE, (unsigned char*)blks, (unsigned char*)blks);
}
static inline void AES_ecb_decrypt_blks(block *blks, unsigned nblks, AES_KEY *key) {
static void ecb_decrypt_blks(block *blks, unsigned nblks, KEY *key) {
nettle_aes_decrypt(key, nblks * AES_BLOCK_SIZE, (unsigned char*)blks, (unsigned char*)blks);
}
} // namespace ocb_aes
#define BPI 4 /* Number of blocks in buffer per ECB call */
#pragma GCC diagnostic pop
@@ -527,8 +557,8 @@ struct _ae_ctx {
uint64_t KtopStr[3]; /* Register correct, each item */
uint32_t ad_blocks_processed;
uint32_t blocks_processed;
AES_KEY decrypt_key;
AES_KEY encrypt_key;
ocb_aes::KEY decrypt_key;
ocb_aes::KEY encrypt_key;
#if (OCB_TAG_LEN == 0)
unsigned tag_len;
#endif
@@ -590,16 +620,16 @@ int ae_init(ae_ctx *ctx, const void *key, int key_len, int nonce_len, int tag_le
#if (OCB_KEY_LEN > 0)
key_len = OCB_KEY_LEN;
#endif
AES_set_encrypt_key((unsigned char *)key, key_len*8, &ctx->encrypt_key);
AES_set_decrypt_key((unsigned char *)key, (int)(key_len*8), &ctx->decrypt_key);
ocb_aes::set_encrypt_key(reinterpret_cast<const unsigned char *>(key), key_len*8, &ctx->encrypt_key);
ocb_aes::set_decrypt_key(reinterpret_cast<const unsigned char *>(key), static_cast<int>(key_len*8), &ctx->decrypt_key);
/* Zero things that need zeroing */
ctx->cached_Top = ctx->ad_checksum = zero_block();
ctx->ad_blocks_processed = 0;
/* Compute key-dependent values */
AES_encrypt((unsigned char *)&ctx->cached_Top,
(unsigned char *)&ctx->Lstar, &ctx->encrypt_key);
ocb_aes::encrypt(reinterpret_cast<unsigned char *>(&ctx->cached_Top),
reinterpret_cast<unsigned char *>(&ctx->Lstar), &ctx->encrypt_key);
tmp_blk = swap_if_le(ctx->Lstar);
tmp_blk = double_block(tmp_blk);
ctx->Ldollar = swap_if_le(tmp_blk);
@@ -636,7 +666,7 @@ static block gen_offset_from_nonce(ae_ctx *ctx, const void *nonce)
tmp.u8[15] = tmp.u8[15] & 0xc0; /* Zero low 6 bits of nonce */
if ( unequal_blocks(tmp.bl,ctx->cached_Top) ) { /* Cached? */
ctx->cached_Top = tmp.bl; /* Update cache, KtopStr */
AES_encrypt(tmp.u8, (unsigned char *)&ctx->KtopStr, &ctx->encrypt_key);
ocb_aes::encrypt(tmp.u8, (unsigned char *)&ctx->KtopStr, &ctx->encrypt_key);
if (little.endian) { /* Make Register Correct */
ctx->KtopStr[0] = bswap64(ctx->KtopStr[0]);
ctx->KtopStr[1] = bswap64(ctx->KtopStr[1]);
@@ -684,7 +714,7 @@ static void process_ad(ae_ctx *ctx, const void *ad, int ad_len, int final)
ad_offset = xor_block(oa[6], getL(ctx, tz));
ta[7] = xor_block(ad_offset, adp[7]);
#endif
AES_ecb_encrypt_blks(ta,BPI,&ctx->encrypt_key);
ocb_aes::ecb_encrypt_blks(ta, BPI, &ctx->encrypt_key);
ad_checksum = xor_block(ad_checksum, ta[0]);
ad_checksum = xor_block(ad_checksum, ta[1]);
ad_checksum = xor_block(ad_checksum, ta[2]);
@@ -745,7 +775,7 @@ static void process_ad(ae_ctx *ctx, const void *ad, int ad_len, int final)
ta[k] = xor_block(ad_offset, tmp.bl);
++k;
}
AES_ecb_encrypt_blks(ta,k,&ctx->encrypt_key);
ocb_aes::ecb_encrypt_blks(ta, k, &ctx->encrypt_key);
switch (k) {
#if (BPI == 8)
case 8: ad_checksum = xor_block(ad_checksum, ta[7]);
@@ -842,7 +872,7 @@ int ae_encrypt(ae_ctx * ctx,
ta[7] = xor_block(oa[7], ptp[7]);
checksum = xor_block(checksum, ptp[7]);
#endif
AES_ecb_encrypt_blks(ta,BPI,&ctx->encrypt_key);
ocb_aes::ecb_encrypt_blks(ta, BPI, &ctx->encrypt_key);
ctp[0] = xor_block(ta[0], oa[0]);
ctp[1] = xor_block(ta[1], oa[1]);
ctp[2] = xor_block(ta[2], oa[2]);
@@ -914,7 +944,7 @@ int ae_encrypt(ae_ctx * ctx,
}
offset = xor_block(offset, ctx->Ldollar); /* Part of tag gen */
ta[k] = xor_block(offset, checksum); /* Part of tag gen */
AES_ecb_encrypt_blks(ta,k+1,&ctx->encrypt_key);
ocb_aes::ecb_encrypt_blks(ta, k + 1, &ctx->encrypt_key);
offset = xor_block(ta[k], ctx->ad_checksum); /* Part of tag gen */
if (remaining) {
--k;
@@ -1055,7 +1085,7 @@ int ae_decrypt(ae_ctx *ctx,
oa[7] = xor_block(oa[6], getL(ctx, ntz(block_num)));
ta[7] = xor_block(oa[7], ctp[7]);
#endif
AES_ecb_decrypt_blks(ta,BPI,&ctx->decrypt_key);
ocb_aes::ecb_decrypt_blks(ta,BPI,&ctx->decrypt_key);
ptp[0] = xor_block(ta[0], oa[0]);
checksum = xor_block(checksum, ptp[0]);
ptp[1] = xor_block(ta[1], oa[1]);
@@ -1120,7 +1150,7 @@ int ae_decrypt(ae_ctx *ctx,
if (remaining) {
block pad;
offset = xor_block(offset,ctx->Lstar);
AES_encrypt((unsigned char *)&offset, tmp.u8, &ctx->encrypt_key);
ocb_aes::encrypt(reinterpret_cast<unsigned char *>(&offset), tmp.u8, &ctx->encrypt_key);
pad = tmp.bl;
memcpy(tmp.u8,ctp+k,remaining);
tmp.bl = xor_block(tmp.bl, pad);
@@ -1129,7 +1159,7 @@ int ae_decrypt(ae_ctx *ctx,
checksum = xor_block(checksum, tmp.bl);
}
}
AES_ecb_decrypt_blks(ta,k,&ctx->decrypt_key);
ocb_aes::ecb_decrypt_blks(ta,k,&ctx->decrypt_key);
switch (k) {
#if (BPI == 8)
case 7: ptp[6] = xor_block(ta[6], oa[6]);
@@ -1158,7 +1188,7 @@ int ae_decrypt(ae_ctx *ctx,
/* Calculate expected tag */
offset = xor_block(offset, ctx->Ldollar);
tmp.bl = xor_block(offset, checksum);
AES_encrypt(tmp.u8, tmp.u8, &ctx->encrypt_key);
ocb_aes::encrypt(tmp.u8, tmp.u8, &ctx->encrypt_key);
tmp.bl = xor_block(tmp.bl, ctx->ad_checksum); /* Full tag */
/* Compare with proposed tag, change ct_len if invalid */