Convert cipher.hh to plain C

This commit is contained in:
Richard van der Hoff 2016-05-16 16:25:09 +01:00
parent f9139dfa6a
commit 294cf482ea
8 changed files with 263 additions and 210 deletions

134
include/olm/cipher.h Normal file
View file

@ -0,0 +1,134 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef OLM_CIPHER_H_
#define OLM_CIPHER_H_
#include <stdint.h>
#include <stdlib.h>
#ifdef __cplusplus
extern "C" {
#endif
struct olm_cipher;
struct cipher_ops {
/**
* Returns the length of the message authentication code that will be
* appended to the output.
*/
size_t (*mac_length)(const struct olm_cipher *cipher);
/**
* Returns the length of cipher-text for a given length of plain-text.
*/
size_t (*encrypt_ciphertext_length)(const struct olm_cipher *cipher,
size_t plaintext_length);
/*
* Encrypts the plain-text into the output buffer and authenticates the
* contents of the output buffer covering both cipher-text and any other
* associated data in the output buffer.
*
* |---------------------------------------output_length-->|
* output |--ciphertext_length-->| |---mac_length-->|
* ciphertext
*
* The plain-text pointers and cipher-text pointers may be the same.
*
* Returns size_t(-1) if the length of the cipher-text or the output
* buffer is too small. Otherwise returns the length of the output buffer.
*/
size_t (*encrypt)(
const struct olm_cipher *cipher,
uint8_t const * key, size_t key_length,
uint8_t const * plaintext, size_t plaintext_length,
uint8_t * ciphertext, size_t ciphertext_length,
uint8_t * output, size_t output_length
);
/**
* Returns the maximum length of plain-text that a given length of
* cipher-text can contain.
*/
size_t (*decrypt_max_plaintext_length)(
const struct olm_cipher *cipher,
size_t ciphertext_length
);
/**
* Authenticates the input and decrypts the cipher-text into the plain-text
* buffer.
*
* |----------------------------------------input_length-->|
* input |--ciphertext_length-->| |---mac_length-->|
* ciphertext
*
* The plain-text pointers and cipher-text pointers may be the same.
*
* Returns size_t(-1) if the length of the plain-text buffer is too
* small or if the authentication check fails. Otherwise returns the length
* of the plain text.
*/
size_t (*decrypt)(
const struct olm_cipher *cipher,
uint8_t const * key, size_t key_length,
uint8_t const * input, size_t input_length,
uint8_t const * ciphertext, size_t ciphertext_length,
uint8_t * plaintext, size_t max_plaintext_length
);
/** destroy any private data associated with this cipher */
void (*destruct)(struct olm_cipher *cipher);
};
struct olm_cipher {
const struct cipher_ops *ops;
/* cipher-specific fields follow */
};
struct olm_cipher_aes_sha_256 {
struct olm_cipher base_cipher;
uint8_t const * kdf_info;
size_t kdf_info_length;
};
/**
* initialises a cipher type which uses AES256 for encryption and SHA256 for
* authentication.
*
* cipher: structure to be initialised
*
* kdf_info: context string for the HKDF used for deriving the AES256 key, HMAC
* key, and AES IV, from the key material passed to encrypt/decrypt. Note that
* this is NOT copied so must have a lifetime at least as long as the cipher
* instance.
*
* kdf_info_length: length of context string kdf_info
*/
struct olm_cipher *olm_cipher_aes_sha_256_init(
struct olm_cipher_aes_sha_256 *cipher,
uint8_t const * kdf_info,
size_t kdf_info_length);
#ifdef __cplusplus
} /* extern "C" */
#endif
#endif /* OLM_CIPHER_H_ */

View file

@ -1,132 +0,0 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef OLM_CIPHER_HH_
#define OLM_CIPHER_HH_
#include <cstdint>
#include <cstddef>
namespace olm {
class Cipher {
public:
virtual ~Cipher();
/**
* Returns the length of the message authentication code that will be
* appended to the output.
*/
virtual std::size_t mac_length() const = 0;
/**
* Returns the length of cipher-text for a given length of plain-text.
*/
virtual std::size_t encrypt_ciphertext_length(
std::size_t plaintext_length
) const = 0;
/*
* Encrypts the plain-text into the output buffer and authenticates the
* contents of the output buffer covering both cipher-text and any other
* associated data in the output buffer.
*
* |---------------------------------------output_length-->|
* output |--ciphertext_length-->| |---mac_length-->|
* ciphertext
*
* The plain-text pointers and cipher-text pointers may be the same.
*
* Returns std::size_t(-1) if the length of the cipher-text or the output
* buffer is too small. Otherwise returns the length of the output buffer.
*/
virtual std::size_t encrypt(
std::uint8_t const * key, std::size_t key_length,
std::uint8_t const * plaintext, std::size_t plaintext_length,
std::uint8_t * ciphertext, std::size_t ciphertext_length,
std::uint8_t * output, std::size_t output_length
) const = 0;
/**
* Returns the maximum length of plain-text that a given length of
* cipher-text can contain.
*/
virtual std::size_t decrypt_max_plaintext_length(
std::size_t ciphertext_length
) const = 0;
/**
* Authenticates the input and decrypts the cipher-text into the plain-text
* buffer.
*
* |----------------------------------------input_length-->|
* input |--ciphertext_length-->| |---mac_length-->|
* ciphertext
*
* The plain-text pointers and cipher-text pointers may be the same.
*
* Returns std::size_t(-1) if the length of the plain-text buffer is too
* small or if the authentication check fails. Otherwise returns the length
* of the plain text.
*/
virtual std::size_t decrypt(
std::uint8_t const * key, std::size_t key_length,
std::uint8_t const * input, std::size_t input_length,
std::uint8_t const * ciphertext, std::size_t ciphertext_length,
std::uint8_t * plaintext, std::size_t max_plaintext_length
) const = 0;
};
class CipherAesSha256 : public Cipher {
public:
CipherAesSha256(
std::uint8_t const * kdf_info, std::size_t kdf_info_length
);
virtual std::size_t mac_length() const;
virtual std::size_t encrypt_ciphertext_length(
std::size_t plaintext_length
) const;
virtual std::size_t encrypt(
std::uint8_t const * key, std::size_t key_length,
std::uint8_t const * plaintext, std::size_t plaintext_length,
std::uint8_t * ciphertext, std::size_t ciphertext_length,
std::uint8_t * output, std::size_t output_length
) const;
virtual std::size_t decrypt_max_plaintext_length(
std::size_t ciphertext_length
) const;
virtual std::size_t decrypt(
std::uint8_t const * key, std::size_t key_length,
std::uint8_t const * input, std::size_t input_length,
std::uint8_t const * ciphertext, std::size_t ciphertext_length,
std::uint8_t * plaintext, std::size_t max_plaintext_length
) const;
private:
std::uint8_t const * kdf_info;
std::size_t kdf_info_length;
};
} // namespace
#endif /* OLM_CIPHER_HH_ */

View file

@ -17,9 +17,9 @@
#include "olm/list.hh" #include "olm/list.hh"
#include "olm/error.h" #include "olm/error.h"
namespace olm { struct olm_cipher;
class Cipher; namespace olm {
typedef std::uint8_t SharedKey[olm::KEY_LENGTH]; typedef std::uint8_t SharedKey[olm::KEY_LENGTH];
@ -69,14 +69,14 @@ struct Ratchet {
Ratchet( Ratchet(
KdfInfo const & kdf_info, KdfInfo const & kdf_info,
Cipher const & ratchet_cipher olm_cipher const *ratchet_cipher
); );
/** A some strings identifying the application to feed into the KDF. */ /** A some strings identifying the application to feed into the KDF. */
KdfInfo const & kdf_info; KdfInfo const & kdf_info;
/** The AEAD cipher to use for encrypting messages. */ /** The AEAD cipher to use for encrypting messages. */
Cipher const & ratchet_cipher; olm_cipher const *ratchet_cipher;
/** The last error that happened encrypting or decrypting a message. */ /** The last error that happened encrypting or decrypting a message. */
OlmErrorCode last_error; OlmErrorCode last_error;

View file

@ -12,15 +12,11 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "olm/cipher.hh" #include "olm/cipher.h"
#include "olm/crypto.hh" #include "olm/crypto.hh"
#include "olm/memory.hh" #include "olm/memory.hh"
#include <cstring> #include <cstring>
olm::Cipher::~Cipher() {
}
namespace { namespace {
struct DerivedKeys { struct DerivedKeys {
@ -51,41 +47,34 @@ static void derive_keys(
static const std::size_t MAC_LENGTH = 8; static const std::size_t MAC_LENGTH = 8;
} // namespace size_t aes_sha_256_cipher_mac_length(const struct olm_cipher *cipher) {
olm::CipherAesSha256::CipherAesSha256(
std::uint8_t const * kdf_info, std::size_t kdf_info_length
) : kdf_info(kdf_info), kdf_info_length(kdf_info_length) {
}
std::size_t olm::CipherAesSha256::mac_length() const {
return MAC_LENGTH; return MAC_LENGTH;
} }
size_t aes_sha_256_cipher_encrypt_ciphertext_length(
std::size_t olm::CipherAesSha256::encrypt_ciphertext_length( const struct olm_cipher *cipher, size_t plaintext_length
std::size_t plaintext_length ) {
) const {
return olm::aes_encrypt_cbc_length(plaintext_length); return olm::aes_encrypt_cbc_length(plaintext_length);
} }
size_t aes_sha_256_cipher_encrypt(
const struct olm_cipher *cipher,
uint8_t const * key, size_t key_length,
uint8_t const * plaintext, size_t plaintext_length,
uint8_t * ciphertext, size_t ciphertext_length,
uint8_t * output, size_t output_length
) {
auto *c = reinterpret_cast<const olm_cipher_aes_sha_256 *>(cipher);
std::size_t olm::CipherAesSha256::encrypt( if (aes_sha_256_cipher_encrypt_ciphertext_length(cipher, plaintext_length)
std::uint8_t const * key, std::size_t key_length, < ciphertext_length) {
std::uint8_t const * plaintext, std::size_t plaintext_length,
std::uint8_t * ciphertext, std::size_t ciphertext_length,
std::uint8_t * output, std::size_t output_length
) const {
if (encrypt_ciphertext_length(plaintext_length) < ciphertext_length) {
return std::size_t(-1); return std::size_t(-1);
} }
struct DerivedKeys keys; struct DerivedKeys keys;
std::uint8_t mac[SHA256_OUTPUT_LENGTH]; std::uint8_t mac[SHA256_OUTPUT_LENGTH];
derive_keys(kdf_info, kdf_info_length, key, key_length, keys); derive_keys(c->kdf_info, c->kdf_info_length, key, key_length, keys);
olm::aes_encrypt_cbc( olm::aes_encrypt_cbc(
keys.aes_key, keys.aes_iv, plaintext, plaintext_length, ciphertext keys.aes_key, keys.aes_iv, plaintext, plaintext_length, ciphertext
@ -102,22 +91,26 @@ std::size_t olm::CipherAesSha256::encrypt(
} }
std::size_t olm::CipherAesSha256::decrypt_max_plaintext_length( size_t aes_sha_256_cipher_decrypt_max_plaintext_length(
std::size_t ciphertext_length const struct olm_cipher *cipher,
) const { size_t ciphertext_length
) {
return ciphertext_length; return ciphertext_length;
} }
std::size_t olm::CipherAesSha256::decrypt( size_t aes_sha_256_cipher_decrypt(
std::uint8_t const * key, std::size_t key_length, const struct olm_cipher *cipher,
std::uint8_t const * input, std::size_t input_length, uint8_t const * key, size_t key_length,
std::uint8_t const * ciphertext, std::size_t ciphertext_length, uint8_t const * input, size_t input_length,
std::uint8_t * plaintext, std::size_t max_plaintext_length uint8_t const * ciphertext, size_t ciphertext_length,
) const { uint8_t * plaintext, size_t max_plaintext_length
) {
auto *c = reinterpret_cast<const olm_cipher_aes_sha_256 *>(cipher);
DerivedKeys keys; DerivedKeys keys;
std::uint8_t mac[SHA256_OUTPUT_LENGTH]; std::uint8_t mac[SHA256_OUTPUT_LENGTH];
derive_keys(kdf_info, kdf_info_length, key, key_length, keys); derive_keys(c->kdf_info, c->kdf_info_length, key, key_length, keys);
crypto_hmac_sha256( crypto_hmac_sha256(
keys.mac_key, olm::KEY_LENGTH, input, input_length - MAC_LENGTH, mac keys.mac_key, olm::KEY_LENGTH, input, input_length - MAC_LENGTH, mac
@ -136,3 +129,30 @@ std::size_t olm::CipherAesSha256::decrypt(
olm::unset(keys); olm::unset(keys);
return plaintext_length; return plaintext_length;
} }
void aes_sha_256_cipher_destruct(struct olm_cipher *cipher) {
}
const cipher_ops aes_sha_256_cipher_ops = {
aes_sha_256_cipher_mac_length,
aes_sha_256_cipher_encrypt_ciphertext_length,
aes_sha_256_cipher_encrypt,
aes_sha_256_cipher_decrypt_max_plaintext_length,
aes_sha_256_cipher_decrypt,
aes_sha_256_cipher_destruct
};
} // namespace
olm_cipher *olm_cipher_aes_sha_256_init(struct olm_cipher_aes_sha_256 *cipher,
uint8_t const * kdf_info,
size_t kdf_info_length)
{
cipher->base_cipher.ops = &aes_sha_256_cipher_ops;
cipher->kdf_info = kdf_info;
cipher->kdf_info_length = kdf_info_length;
return &(cipher->base_cipher);
}

View file

@ -15,9 +15,9 @@
#include "olm/olm.h" #include "olm/olm.h"
#include "olm/session.hh" #include "olm/session.hh"
#include "olm/account.hh" #include "olm/account.hh"
#include "olm/cipher.h"
#include "olm/utility.hh" #include "olm/utility.hh"
#include "olm/base64.hh" #include "olm/base64.hh"
#include "olm/cipher.hh"
#include "olm/memory.hh" #include "olm/memory.hh"
#include <new> #include <new>
@ -59,15 +59,24 @@ static std::uint8_t const * from_c(void const * bytes) {
static const std::uint8_t CIPHER_KDF_INFO[] = "Pickle"; static const std::uint8_t CIPHER_KDF_INFO[] = "Pickle";
static const olm::CipherAesSha256 PICKLE_CIPHER( const olm_cipher *get_pickle_cipher() {
CIPHER_KDF_INFO, sizeof(CIPHER_KDF_INFO) -1 static olm_cipher *cipher = NULL;
); static olm_cipher_aes_sha_256 PICKLE_CIPHER;
if (!cipher) {
cipher = olm_cipher_aes_sha_256_init(
&PICKLE_CIPHER,
CIPHER_KDF_INFO, sizeof(CIPHER_KDF_INFO) - 1
);
}
return cipher;
}
std::size_t enc_output_length( std::size_t enc_output_length(
size_t raw_length size_t raw_length
) { ) {
std::size_t length = PICKLE_CIPHER.encrypt_ciphertext_length(raw_length); auto *cipher = get_pickle_cipher();
length += PICKLE_CIPHER.mac_length(); std::size_t length = cipher->ops->encrypt_ciphertext_length(cipher, raw_length);
length += cipher->ops->mac_length(cipher);
return olm::encode_base64_length(length); return olm::encode_base64_length(length);
} }
@ -76,8 +85,9 @@ std::uint8_t * enc_output_pos(
std::uint8_t * output, std::uint8_t * output,
size_t raw_length size_t raw_length
) { ) {
std::size_t length = PICKLE_CIPHER.encrypt_ciphertext_length(raw_length); auto *cipher = get_pickle_cipher();
length += PICKLE_CIPHER.mac_length(); std::size_t length = cipher->ops->encrypt_ciphertext_length(cipher, raw_length);
length += cipher->ops->mac_length(cipher);
return output + olm::encode_base64_length(length) - length; return output + olm::encode_base64_length(length) - length;
} }
@ -85,13 +95,15 @@ std::size_t enc_output(
std::uint8_t const * key, std::size_t key_length, std::uint8_t const * key, std::size_t key_length,
std::uint8_t * output, size_t raw_length std::uint8_t * output, size_t raw_length
) { ) {
std::size_t ciphertext_length = PICKLE_CIPHER.encrypt_ciphertext_length( auto *cipher = get_pickle_cipher();
raw_length std::size_t ciphertext_length = cipher->ops->encrypt_ciphertext_length(
cipher, raw_length
); );
std::size_t length = ciphertext_length + PICKLE_CIPHER.mac_length(); std::size_t length = ciphertext_length + cipher->ops->mac_length(cipher);
std::size_t base64_length = olm::encode_base64_length(length); std::size_t base64_length = olm::encode_base64_length(length);
std::uint8_t * raw_output = output + base64_length - length; std::uint8_t * raw_output = output + base64_length - length;
PICKLE_CIPHER.encrypt( cipher->ops->encrypt(
cipher,
key, key_length, key, key_length,
raw_output, raw_length, raw_output, raw_length,
raw_output, ciphertext_length, raw_output, ciphertext_length,
@ -112,8 +124,10 @@ std::size_t enc_input(
return std::size_t(-1); return std::size_t(-1);
} }
olm::decode_base64(input, b64_length, input); olm::decode_base64(input, b64_length, input);
std::size_t raw_length = enc_length - PICKLE_CIPHER.mac_length(); auto *cipher = get_pickle_cipher();
std::size_t result = PICKLE_CIPHER.decrypt( std::size_t raw_length = enc_length - cipher->ops->mac_length(cipher);
std::size_t result = cipher->ops->decrypt(
cipher,
key, key_length, key, key_length,
input, enc_length, input, enc_length,
input, raw_length, input, raw_length,

View file

@ -15,7 +15,7 @@
#include "olm/ratchet.hh" #include "olm/ratchet.hh"
#include "olm/message.hh" #include "olm/message.hh"
#include "olm/memory.hh" #include "olm/memory.hh"
#include "olm/cipher.hh" #include "olm/cipher.h"
#include "olm/pickle.hh" #include "olm/pickle.hh"
#include <cstring> #include <cstring>
@ -94,12 +94,13 @@ static void create_message_keys(
static std::size_t verify_mac_and_decrypt( static std::size_t verify_mac_and_decrypt(
olm::Cipher const & cipher, olm_cipher const *cipher,
olm::MessageKey const & message_key, olm::MessageKey const & message_key,
olm::MessageReader const & reader, olm::MessageReader const & reader,
std::uint8_t * plaintext, std::size_t max_plaintext_length std::uint8_t * plaintext, std::size_t max_plaintext_length
) { ) {
return cipher.decrypt( return cipher->ops->decrypt(
cipher,
message_key.key, sizeof(message_key.key), message_key.key, sizeof(message_key.key),
reader.input, reader.input_length, reader.input, reader.input_length,
reader.ciphertext, reader.ciphertext_length, reader.ciphertext, reader.ciphertext_length,
@ -183,7 +184,7 @@ static std::size_t verify_mac_and_decrypt_for_new_chain(
olm::Ratchet::Ratchet( olm::Ratchet::Ratchet(
olm::KdfInfo const & kdf_info, olm::KdfInfo const & kdf_info,
Cipher const & ratchet_cipher olm_cipher const * ratchet_cipher
) : kdf_info(kdf_info), ) : kdf_info(kdf_info),
ratchet_cipher(ratchet_cipher), ratchet_cipher(ratchet_cipher),
last_error(OlmErrorCode::OLM_SUCCESS) { last_error(OlmErrorCode::OLM_SUCCESS) {
@ -405,11 +406,12 @@ std::size_t olm::Ratchet::encrypt_output_length(
if (!sender_chain.empty()) { if (!sender_chain.empty()) {
counter = sender_chain[0].chain_key.index; counter = sender_chain[0].chain_key.index;
} }
std::size_t padded = ratchet_cipher.encrypt_ciphertext_length( std::size_t padded = ratchet_cipher->ops->encrypt_ciphertext_length(
ratchet_cipher,
plaintext_length plaintext_length
); );
return olm::encode_message_length( return olm::encode_message_length(
counter, olm::KEY_LENGTH, padded, ratchet_cipher.mac_length() counter, olm::KEY_LENGTH, padded, ratchet_cipher->ops->mac_length(ratchet_cipher)
); );
} }
@ -452,7 +454,8 @@ std::size_t olm::Ratchet::encrypt(
create_message_keys(chain_index, sender_chain[0].chain_key, kdf_info, keys); create_message_keys(chain_index, sender_chain[0].chain_key, kdf_info, keys);
advance_chain_key(chain_index, sender_chain[0].chain_key, sender_chain[0].chain_key); advance_chain_key(chain_index, sender_chain[0].chain_key, sender_chain[0].chain_key);
std::size_t ciphertext_length = ratchet_cipher.encrypt_ciphertext_length( std::size_t ciphertext_length = ratchet_cipher->ops->encrypt_ciphertext_length(
ratchet_cipher,
plaintext_length plaintext_length
); );
std::uint32_t counter = keys.index; std::uint32_t counter = keys.index;
@ -467,7 +470,8 @@ std::size_t olm::Ratchet::encrypt(
olm::store_array(writer.ratchet_key, ratchet_key.public_key); olm::store_array(writer.ratchet_key, ratchet_key.public_key);
ratchet_cipher.encrypt( ratchet_cipher->ops->encrypt(
ratchet_cipher,
keys.key, sizeof(keys.key), keys.key, sizeof(keys.key),
plaintext, plaintext_length, plaintext, plaintext_length,
writer.ciphertext, ciphertext_length, writer.ciphertext, ciphertext_length,
@ -484,7 +488,8 @@ std::size_t olm::Ratchet::decrypt_max_plaintext_length(
) { ) {
olm::MessageReader reader; olm::MessageReader reader;
olm::decode_message( olm::decode_message(
reader, input, input_length, ratchet_cipher.mac_length() reader, input, input_length,
ratchet_cipher->ops->mac_length(ratchet_cipher)
); );
if (!reader.ciphertext) { if (!reader.ciphertext) {
@ -492,7 +497,8 @@ std::size_t olm::Ratchet::decrypt_max_plaintext_length(
return std::size_t(-1); return std::size_t(-1);
} }
return ratchet_cipher.decrypt_max_plaintext_length(reader.ciphertext_length); return ratchet_cipher->ops->decrypt_max_plaintext_length(
ratchet_cipher, reader.ciphertext_length);
} }
@ -502,7 +508,8 @@ std::size_t olm::Ratchet::decrypt(
) { ) {
olm::MessageReader reader; olm::MessageReader reader;
olm::decode_message( olm::decode_message(
reader, input, input_length, ratchet_cipher.mac_length() reader, input, input_length,
ratchet_cipher->ops->mac_length(ratchet_cipher)
); );
if (reader.version != PROTOCOL_VERSION) { if (reader.version != PROTOCOL_VERSION) {
@ -515,7 +522,8 @@ std::size_t olm::Ratchet::decrypt(
return std::size_t(-1); return std::size_t(-1);
} }
std::size_t max_length = ratchet_cipher.decrypt_max_plaintext_length( std::size_t max_length = ratchet_cipher->ops->decrypt_max_plaintext_length(
ratchet_cipher,
reader.ciphertext_length reader.ciphertext_length
); );

View file

@ -13,7 +13,7 @@
* limitations under the License. * limitations under the License.
*/ */
#include "olm/session.hh" #include "olm/session.hh"
#include "olm/cipher.hh" #include "olm/cipher.h"
#include "olm/crypto.hh" #include "olm/crypto.hh"
#include "olm/account.hh" #include "olm/account.hh"
#include "olm/memory.hh" #include "olm/memory.hh"
@ -30,19 +30,27 @@ static const std::uint8_t ROOT_KDF_INFO[] = "OLM_ROOT";
static const std::uint8_t RATCHET_KDF_INFO[] = "OLM_RATCHET"; static const std::uint8_t RATCHET_KDF_INFO[] = "OLM_RATCHET";
static const std::uint8_t CIPHER_KDF_INFO[] = "OLM_KEYS"; static const std::uint8_t CIPHER_KDF_INFO[] = "OLM_KEYS";
static const olm::CipherAesSha256 OLM_CIPHER(
CIPHER_KDF_INFO, sizeof(CIPHER_KDF_INFO) -1
);
static const olm::KdfInfo OLM_KDF_INFO = { static const olm::KdfInfo OLM_KDF_INFO = {
ROOT_KDF_INFO, sizeof(ROOT_KDF_INFO) - 1, ROOT_KDF_INFO, sizeof(ROOT_KDF_INFO) - 1,
RATCHET_KDF_INFO, sizeof(RATCHET_KDF_INFO) - 1 RATCHET_KDF_INFO, sizeof(RATCHET_KDF_INFO) - 1
}; };
const olm_cipher *get_cipher() {
static olm_cipher *cipher;
static olm_cipher_aes_sha_256 OLM_CIPHER;
if (!cipher) {
cipher = olm_cipher_aes_sha_256_init(
&OLM_CIPHER,
CIPHER_KDF_INFO, sizeof(CIPHER_KDF_INFO) - 1
);
}
return cipher;
}
} // namespace } // namespace
olm::Session::Session( olm::Session::Session(
) : ratchet(OLM_KDF_INFO, OLM_CIPHER), ) : ratchet(OLM_KDF_INFO, get_cipher()),
last_error(OlmErrorCode::OLM_SUCCESS), last_error(OlmErrorCode::OLM_SUCCESS),
received_message(false) { received_message(false) {
@ -149,7 +157,7 @@ std::size_t olm::Session::new_inbound_session(
olm::MessageReader message_reader; olm::MessageReader message_reader;
decode_message( decode_message(
message_reader, reader.message, reader.message_length, message_reader, reader.message, reader.message_length,
ratchet.ratchet_cipher.mac_length() ratchet.ratchet_cipher->ops->mac_length(ratchet.ratchet_cipher)
); );
if (!message_reader.ratchet_key if (!message_reader.ratchet_key

View file

@ -13,7 +13,7 @@
* limitations under the License. * limitations under the License.
*/ */
#include "olm/ratchet.hh" #include "olm/ratchet.hh"
#include "olm/cipher.hh" #include "olm/cipher.h"
#include "unittest.hh" #include "unittest.hh"
@ -28,8 +28,9 @@ olm::KdfInfo kdf_info = {
ratchet_info, sizeof(ratchet_info) - 1 ratchet_info, sizeof(ratchet_info) - 1
}; };
olm::CipherAesSha256 cipher( olm_cipher_aes_sha_256 cipher0;
message_info, sizeof(message_info) - 1 olm_cipher *cipher = olm_cipher_aes_sha_256_init(
&cipher0, message_info, sizeof(message_info) - 1
); );
std::uint8_t random_bytes[] = "0123456789ABDEF0123456789ABCDEF"; std::uint8_t random_bytes[] = "0123456789ABDEF0123456789ABCDEF";