diff --git a/include/olm/megolm.h b/include/olm/megolm.h index 831c6fb..e4e5d0b 100644 --- a/include/olm/megolm.h +++ b/include/olm/megolm.h @@ -49,11 +49,11 @@ typedef struct Megolm { /** - * Get the cipher used in megolm-backed conversations + * The cipher used in megolm-backed conversations * * (AES256 + SHA256, with keys based on an HKDF with info of MEGOLM_KEYS) */ -const struct _olm_cipher *megolm_cipher(); +extern const struct _olm_cipher *megolm_cipher; /** * initialize the megolm ratchet. random_data should be at least diff --git a/src/inbound_group_session.c b/src/inbound_group_session.c index 6cded75..b6894c1 100644 --- a/src/inbound_group_session.c +++ b/src/inbound_group_session.c @@ -168,7 +168,6 @@ size_t olm_group_decrypt_max_plaintext_length( uint8_t * message, size_t message_length ) { size_t r; - const struct _olm_cipher *cipher = megolm_cipher(); struct _OlmDecodeGroupMessageResults decoded_results; r = _olm_decode_base64(message, message_length, message); @@ -179,7 +178,7 @@ size_t olm_group_decrypt_max_plaintext_length( _olm_decode_group_message( message, message_length, - cipher->ops->mac_length(cipher), + megolm_cipher->ops->mac_length(megolm_cipher), &decoded_results); if (decoded_results.version != OLM_PROTOCOL_VERSION) { @@ -192,8 +191,8 @@ size_t olm_group_decrypt_max_plaintext_length( return (size_t)-1; } - return cipher->ops->decrypt_max_plaintext_length( - cipher, decoded_results.ciphertext_length); + return megolm_cipher->ops->decrypt_max_plaintext_length( + megolm_cipher, decoded_results.ciphertext_length); } @@ -203,7 +202,6 @@ size_t olm_group_decrypt( uint8_t * plaintext, size_t max_plaintext_length ) { struct _OlmDecodeGroupMessageResults decoded_results; - const struct _olm_cipher *cipher = megolm_cipher(); size_t max_length, raw_message_length, r; Megolm *megolm; Megolm tmp_megolm; @@ -216,7 +214,7 @@ size_t olm_group_decrypt( _olm_decode_group_message( message, raw_message_length, - cipher->ops->mac_length(cipher), + megolm_cipher->ops->mac_length(megolm_cipher), &decoded_results); if (decoded_results.version != OLM_PROTOCOL_VERSION) { @@ -231,8 +229,8 @@ size_t olm_group_decrypt( return (size_t)-1; } - max_length = cipher->ops->decrypt_max_plaintext_length( - cipher, + max_length = megolm_cipher->ops->decrypt_max_plaintext_length( + megolm_cipher, decoded_results.ciphertext_length ); if (max_plaintext_length < max_length) { @@ -258,8 +256,8 @@ size_t olm_group_decrypt( megolm_advance_to(megolm, decoded_results.message_index); /* now try checking the mac, and decrypting */ - r = cipher->ops->decrypt( - cipher, + r = megolm_cipher->ops->decrypt( + megolm_cipher, megolm_get_data(megolm), MEGOLM_RATCHET_LENGTH, message, raw_message_length, decoded_results.ciphertext, decoded_results.ciphertext_length, diff --git a/src/megolm.c b/src/megolm.c index 7567894..110f939 100644 --- a/src/megolm.c +++ b/src/megolm.c @@ -22,18 +22,9 @@ #include "olm/crypto.h" #include "olm/pickle.h" -const struct _olm_cipher *megolm_cipher() { - static const uint8_t CIPHER_KDF_INFO[] = "MEGOLM_KEYS"; - static struct _olm_cipher *cipher; - static struct _olm_cipher_aes_sha_256 OLM_CIPHER; - if (!cipher) { - cipher = _olm_cipher_aes_sha_256_init( - &OLM_CIPHER, - CIPHER_KDF_INFO, sizeof(CIPHER_KDF_INFO) - 1 - ); - } - return cipher; -} +static const struct _olm_cipher_aes_sha_256 MEGOLM_CIPHER = + OLM_CIPHER_INIT_AES_SHA_256("MEGOLM_KEYS"); +const struct _olm_cipher *megolm_cipher = OLM_CIPHER_BASE(&MEGOLM_CIPHER); /* the seeds used in the HMAC-SHA-256 functions for each part of the ratchet. */ diff --git a/src/outbound_group_session.c b/src/outbound_group_session.c index cf7d32c..9f36ad8 100644 --- a/src/outbound_group_session.c +++ b/src/outbound_group_session.c @@ -179,13 +179,12 @@ static size_t raw_message_length( size_t plaintext_length) { size_t ciphertext_length, mac_length; - const struct _olm_cipher *cipher = megolm_cipher(); - ciphertext_length = cipher->ops->encrypt_ciphertext_length( - cipher, plaintext_length + ciphertext_length = megolm_cipher->ops->encrypt_ciphertext_length( + megolm_cipher, plaintext_length ); - mac_length = cipher->ops->mac_length(cipher); + mac_length = megolm_cipher->ops->mac_length(megolm_cipher); return _olm_encode_group_message_length( GROUP_SESSION_ID_LENGTH, session->ratchet.counter, @@ -210,7 +209,6 @@ size_t olm_group_encrypt( size_t rawmsglen; size_t result; uint8_t *ciphertext_ptr, *message_pos; - const struct _olm_cipher *cipher = megolm_cipher(); rawmsglen = raw_message_length(session, plaintext_length); @@ -219,8 +217,8 @@ size_t olm_group_encrypt( return (size_t)-1; } - ciphertext_length = cipher->ops->encrypt_ciphertext_length( - cipher, + ciphertext_length = megolm_cipher->ops->encrypt_ciphertext_length( + megolm_cipher, plaintext_length ); @@ -240,8 +238,8 @@ size_t olm_group_encrypt( message_pos, &ciphertext_ptr); - result = cipher->ops->encrypt( - cipher, + result = megolm_cipher->ops->encrypt( + megolm_cipher, megolm_get_data(&(session->ratchet)), MEGOLM_RATCHET_LENGTH, plaintext, plaintext_length, ciphertext_ptr, ciphertext_length,