From 631f0505540ba12bd4f639d703b1eddd8d4df001 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Mon, 22 Nov 2021 17:16:37 -0500 Subject: [PATCH] add a test for fallback keys, and clear memory when we forget the old fallback --- src/account.cpp | 1 + tests/test_olm.cpp | 211 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 212 insertions(+) diff --git a/src/account.cpp b/src/account.cpp index 5056d5f..41b7188 100644 --- a/src/account.cpp +++ b/src/account.cpp @@ -402,6 +402,7 @@ void olm::Account::forget_old_fallback_key( ) { if (num_fallback_keys >= 2) { num_fallback_keys = 1; + olm::unset(&prev_fallback_key, sizeof(prev_fallback_key)); } } diff --git a/tests/test_olm.cpp b/tests/test_olm.cpp index 30328ed..ae644e5 100644 --- a/tests/test_olm.cpp +++ b/tests/test_olm.cpp @@ -432,6 +432,217 @@ for (unsigned i = 0; i < 8; ++i) { } } +{ /** Fallback key test */ + +TestCase test_case("Fallback key test"); +MockRandom mock_random_a('A', 0x00); +MockRandom mock_random_b('B', 0x80); + +// create a and b acconuts +std::vector a_account_buffer(::olm_account_size()); +::OlmAccount *a_account = ::olm_account(a_account_buffer.data()); +std::vector a_random(::olm_create_account_random_length(a_account)); +mock_random_a(a_random.data(), a_random.size()); +::olm_create_account(a_account, a_random.data(), a_random.size()); + +std::vector b_account_buffer(::olm_account_size()); +::OlmAccount *b_account = ::olm_account(b_account_buffer.data()); +std::vector b_random(::olm_create_account_random_length(b_account)); +mock_random_b(b_random.data(), b_random.size()); +::olm_create_account(b_account, b_random.data(), b_random.size()); + +std::vector a_id_keys(::olm_account_identity_keys_length(a_account)); +::olm_account_identity_keys(a_account, a_id_keys.data(), a_id_keys.size()); + +std::vector b_id_keys(::olm_account_identity_keys_length(b_account)); + +// create and fetch a fallback key for b +std::vector f_random(::olm_account_generate_fallback_key_random_length( + b_account +)); +mock_random_b(f_random.data(), f_random.size()); +::olm_account_generate_fallback_key(b_account, f_random.data(), f_random.size()); +std::vector b_fb_key(::olm_account_unpublished_fallback_key_length(b_account)); +::olm_account_identity_keys(b_account, b_id_keys.data(), b_id_keys.size()); +::olm_account_unpublished_fallback_key(b_account, b_fb_key.data(), b_fb_key.size()); + +// start a new olm session and encrypt a message +std::vector a_session1_buffer(::olm_session_size()); +::OlmSession *a_session1 = ::olm_session(a_session1_buffer.data()); +std::vector a_rand(::olm_create_outbound_session_random_length(a_session1)); +mock_random_a(a_rand.data(), a_rand.size()); +assert_not_equals(std::size_t(-1), ::olm_create_outbound_session( + a_session1, a_account, + b_id_keys.data() + 15, 43, // B's curve25519 identity key + b_fb_key.data() + 25, 43, // B's curve25519 one time key + a_rand.data(), a_rand.size() +)); + +std::uint8_t plaintext[] = "Hello, World"; +std::vector message_1(::olm_encrypt_message_length(a_session1, 12)); +std::vector a_message_random(::olm_encrypt_random_length(a_session1)); +mock_random_a(a_message_random.data(), a_message_random.size()); +assert_equals(std::size_t(0), ::olm_encrypt_message_type(a_session1)); +assert_not_equals(std::size_t(-1), ::olm_encrypt( + a_session1, + plaintext, 12, + a_message_random.data(), a_message_random.size(), + message_1.data(), message_1.size() +)); + + +std::vector tmp_message_1(message_1); +std::vector b_session1_buffer(::olm_session_size()); +::OlmSession *b_session1 = ::olm_session(b_session1_buffer.data()); +::olm_create_inbound_session( + b_session1, b_account, tmp_message_1.data(), message_1.size() +); + +// Check that the inbound session matches the message it was created from. +std::memcpy(tmp_message_1.data(), message_1.data(), message_1.size()); +assert_equals(std::size_t(1), ::olm_matches_inbound_session( + b_session1, + tmp_message_1.data(), message_1.size() +)); + +// Check that the inbound session matches the key this message is supposed +// to be from. +std::memcpy(tmp_message_1.data(), message_1.data(), message_1.size()); +assert_equals(std::size_t(1), ::olm_matches_inbound_session_from( + b_session1, + a_id_keys.data() + 15, 43, // A's curve125519 identity key. + tmp_message_1.data(), message_1.size() +)); + +// Check that the inbound session isn't from a different user. +std::memcpy(tmp_message_1.data(), message_1.data(), message_1.size()); +assert_equals(std::size_t(0), ::olm_matches_inbound_session_from( + b_session1, + b_id_keys.data() + 15, 43, // B's curve25519 identity key. + tmp_message_1.data(), message_1.size() +)); + +// Check that we can decrypt the message. +std::memcpy(tmp_message_1.data(), message_1.data(), message_1.size()); +std::vector plaintext_1(::olm_decrypt_max_plaintext_length( + b_session1, 0, tmp_message_1.data(), message_1.size() +)); +std::memcpy(tmp_message_1.data(), message_1.data(), message_1.size()); +assert_equals(std::size_t(12), ::olm_decrypt( + b_session1, 0, + tmp_message_1.data(), message_1.size(), + plaintext_1.data(), plaintext_1.size() +)); + +assert_equals(plaintext, plaintext_1.data(), 12); + +// create a new fallback key for B (the old fallback should still be usable) +mock_random_b(f_random.data(), f_random.size()); +::olm_account_generate_fallback_key(b_account, f_random.data(), f_random.size()); + + +// start another session and encrypt a message +std::vector a_session2_buffer(::olm_session_size()); +::OlmSession *a_session2 = ::olm_session(a_session2_buffer.data()); +mock_random_a(a_rand.data(), a_rand.size()); +assert_not_equals(std::size_t(-1), ::olm_create_outbound_session( + a_session2, a_account, + b_id_keys.data() + 15, 43, // B's curve25519 identity key + b_fb_key.data() + 25, 43, // B's curve25519 one time key + a_rand.data(), a_rand.size() +)); +std::vector message_2(::olm_encrypt_message_length(a_session2, 12)); +mock_random_a(a_message_random.data(), a_message_random.size()); +assert_equals(std::size_t(0), ::olm_encrypt_message_type(a_session2)); +assert_not_equals(std::size_t(-1), ::olm_encrypt( + a_session2, + plaintext, 12, + a_message_random.data(), a_message_random.size(), + message_2.data(), message_2.size() +)); + + +std::vector tmp_message_2(message_2); +std::vector b_session2_buffer(::olm_session_size()); +::OlmSession *b_session2 = ::olm_session(b_session2_buffer.data()); +assert_not_equals(std::size_t(-1), ::olm_create_inbound_session( + b_session2, b_account, tmp_message_2.data(), message_2.size() +)); + +// Check that the inbound session matches the message it was created from. +std::memcpy(tmp_message_2.data(), message_2.data(), message_2.size()); +assert_equals(std::size_t(1), ::olm_matches_inbound_session( + b_session2, + tmp_message_2.data(), message_2.size() +)); + +// Check that the inbound session matches the key this message is supposed +// to be from. +std::memcpy(tmp_message_2.data(), message_2.data(), message_2.size()); +assert_equals(std::size_t(1), ::olm_matches_inbound_session_from( + b_session2, + a_id_keys.data() + 15, 43, // A's curve125519 identity key. + tmp_message_2.data(), message_2.size() +)); + +// Check that the inbound session isn't from a different user. +std::memcpy(tmp_message_2.data(), message_2.data(), message_2.size()); +assert_equals(std::size_t(0), ::olm_matches_inbound_session_from( + b_session2, + b_id_keys.data() + 15, 43, // B's curve25519 identity key. + tmp_message_2.data(), message_2.size() +)); + +// Check that we can decrypt the message. +std::memcpy(tmp_message_2.data(), message_2.data(), message_2.size()); +std::vector plaintext_2(::olm_decrypt_max_plaintext_length( + b_session2, 0, tmp_message_2.data(), message_2.size() +)); +std::memcpy(tmp_message_2.data(), message_2.data(), message_2.size()); +assert_equals(std::size_t(12), ::olm_decrypt( + b_session2, 0, + tmp_message_2.data(), message_2.size(), + plaintext_2.data(), plaintext_2.size() +)); + +assert_equals(plaintext, plaintext_2.data(), 12); + +// forget the old fallback key -- creating a new session should fail +::olm_account_forget_old_fallback_key(b_account); + +std::vector a_session3_buffer(::olm_session_size()); +::OlmSession *a_session3 = ::olm_session(a_session3_buffer.data()); +mock_random_a(a_rand.data(), a_rand.size()); +assert_not_equals(std::size_t(-1), ::olm_create_outbound_session( + a_session3, a_account, + b_id_keys.data() + 15, 43, // B's curve25519 identity key + b_fb_key.data() + 25, 43, // B's curve25519 one time key + a_rand.data(), a_rand.size() +)); + +std::vector message_3(::olm_encrypt_message_length(a_session3, 12)); +mock_random_a(a_message_random.data(), a_message_random.size()); +assert_equals(std::size_t(0), ::olm_encrypt_message_type(a_session3)); +assert_not_equals(std::size_t(-1), ::olm_encrypt( + a_session3, + plaintext, 12, + a_message_random.data(), a_message_random.size(), + message_3.data(), message_3.size() +)); + + +std::vector tmp_message_3(message_3); +std::vector b_session3_buffer(::olm_session_size()); +::OlmSession *b_session3 = ::olm_session(b_session3_buffer.data()); +assert_equals(std::size_t(-1), ::olm_create_inbound_session( + b_session3, b_account, tmp_message_3.data(), message_3.size() +)); +assert_equals( + std::string("BAD_MESSAGE_KEY_ID"), + std::string(::olm_session_last_error(b_session3)) +); +} + { TestCase test_case("Old account (v3) unpickle test");