From 89d9b972a6d629648d18f4227a08596c65c3894d Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 16 Jul 2015 10:45:10 +0100 Subject: [PATCH] Add versions of olm_session_create_inbound and olm_session_matches_inbound which take the curve25519 identity key of the remote device we think the message is from as an additional argument --- include/olm/olm.hh | 29 ++++++++++++++++++++ include/olm/session.hh | 2 ++ src/olm.cpp | 61 ++++++++++++++++++++++++++++++++++++++++-- src/session.cpp | 38 ++++++++++++++++++++------ 4 files changed, 120 insertions(+), 10 deletions(-) diff --git a/include/olm/olm.hh b/include/olm/olm.hh index 64454e6..f08fb9f 100644 --- a/include/olm/olm.hh +++ b/include/olm/olm.hh @@ -242,6 +242,21 @@ size_t olm_create_inbound_session( void * one_time_key_message, size_t message_length ); +/** Create a new in-bound session for sending/receiving messages from an + * incoming PRE_KEY message. Returns olm_error() on failure. If the base64 + * couldn't be decoded then olm_session_last_error will be "INVALID_BASE64". + * If the message was for an unsupported protocol version then + * olm_session_last_error() will be "BAD_MESSAGE_VERSION". If the message + * couldn't be decoded then then olm_session_last_error() will be + * "BAD_MESSAGE_FORMAT". If the message refers to an unknown one time + * key then olm_session_last_error() will be "BAD_MESSAGE_KEY_ID". */ +size_t olm_create_inbound_session_from( + OlmSession * session, + OlmAccount * account, + void const * their_identity_key, size_t their_identity_key_length, + void * one_time_key_message, size_t message_length +); + /** Checks if the PRE_KEY message is for this in-bound session. This can happen * if multiple messages are sent to this account before this account sends a * message in reply. Returns olm_error() on failure. If the base64 @@ -255,6 +270,20 @@ size_t olm_matches_inbound_session( void * one_time_key_message, size_t message_length ); +/** Checks if the PRE_KEY message is for this in-bound session. This can happen + * if multiple messages are sent to this account before this account sends a + * message in reply. Returns olm_error() on failure. If the base64 + * couldn't be decoded then olm_session_last_error will be "INVALID_BASE64". + * If the message was for an unsupported protocol version then + * olm_session_last_error() will be "BAD_MESSAGE_VERSION". If the message + * couldn't be decoded then then olm_session_last_error() will be + * "BAD_MESSAGE_FORMAT". */ +size_t olm_matches_inbound_session_from( + OlmSession * session, + void const * their_identity_key, size_t their_identity_key_length, + void * one_time_key_message, size_t message_length +); + /** Removes the one time keys that the session used from the account. Returns * olm_error() on failure. If the account doesn't have any matching one time * keys then olm_account_last_error() will be "BAD_MESSAGE_KEY_ID". */ diff --git a/include/olm/session.hh b/include/olm/session.hh index 125df68..b70ce6a 100644 --- a/include/olm/session.hh +++ b/include/olm/session.hh @@ -50,10 +50,12 @@ struct Session { std::size_t new_inbound_session( Account & local_account, + Curve25519PublicKey const * their_identity_key, std::uint8_t const * one_time_key_message, std::size_t message_length ); bool matches_inbound_session( + Curve25519PublicKey const * their_identity_key, std::uint8_t const * one_time_key_message, std::size_t message_length ); diff --git a/src/olm.cpp b/src/olm.cpp index b121ec7..17461fe 100644 --- a/src/olm.cpp +++ b/src/olm.cpp @@ -518,7 +518,36 @@ size_t olm_create_inbound_session( return std::size_t(-1); } return from_c(session)->new_inbound_session( - *from_c(account), from_c(one_time_key_message), raw_length + *from_c(account), nullptr, from_c(one_time_key_message), raw_length + ); +} + + +size_t olm_create_inbound_session_from( + OlmSession * session, + OlmAccount * account, + void const * their_identity_key, size_t their_identity_key_length, + void * one_time_key_message, size_t message_length +) { + if (olm::decode_base64_length(their_identity_key_length) != 32) { + from_c(session)->last_error = olm::ErrorCode::INVALID_BASE64; + return std::size_t(-1); + } + olm::Curve25519PublicKey identity_key; + olm::decode_base64( + from_c(their_identity_key), their_identity_key_length, + identity_key.public_key + ); + + std::size_t raw_length = b64_input( + from_c(one_time_key_message), message_length, from_c(session)->last_error + ); + if (raw_length == std::size_t(-1)) { + return std::size_t(-1); + } + return from_c(session)->new_inbound_session( + *from_c(account), &identity_key, + from_c(one_time_key_message), raw_length ); } @@ -534,7 +563,35 @@ size_t olm_matches_inbound_session( return std::size_t(-1); } bool matches = from_c(session)->matches_inbound_session( - from_c(one_time_key_message), raw_length + nullptr, from_c(one_time_key_message), raw_length + ); + return matches ? 1 : 0; +} + + +size_t olm_matches_inbound_session_from( + OlmSession * session, + void const * their_identity_key, size_t their_identity_key_length, + void * one_time_key_message, size_t message_length +) { + if (olm::decode_base64_length(their_identity_key_length) != 32) { + from_c(session)->last_error = olm::ErrorCode::INVALID_BASE64; + return std::size_t(-1); + } + olm::Curve25519PublicKey identity_key; + olm::decode_base64( + from_c(their_identity_key), their_identity_key_length, + identity_key.public_key + ); + + std::size_t raw_length = b64_input( + from_c(one_time_key_message), message_length, from_c(session)->last_error + ); + if (raw_length == std::size_t(-1)) { + return std::size_t(-1); + } + bool matches = from_c(session)->matches_inbound_session( + &identity_key, from_c(one_time_key_message), raw_length ); return matches ? 1 : 0; } diff --git a/src/session.cpp b/src/session.cpp index 654cf1f..0249e6c 100644 --- a/src/session.cpp +++ b/src/session.cpp @@ -102,11 +102,13 @@ std::size_t olm::Session::new_outbound_session( namespace { bool check_message_fields( - olm::PreKeyMessageReader & reader + olm::PreKeyMessageReader & reader, bool have_their_identity_key ) { bool ok = true; - ok = ok && reader.identity_key; - ok = ok && reader.identity_key_length == KEY_LENGTH; + ok = ok && (have_their_identity_key || reader.identity_key); + if (reader.identity_key) { + ok = ok && reader.identity_key_length == KEY_LENGTH; + } ok = ok && reader.message; ok = ok && reader.base_key; ok = ok && reader.base_key_length == KEY_LENGTH; @@ -120,16 +122,27 @@ bool check_message_fields( std::size_t olm::Session::new_inbound_session( olm::Account & local_account, + olm::Curve25519PublicKey const * their_identity_key, std::uint8_t const * one_time_key_message, std::size_t message_length ) { olm::PreKeyMessageReader reader; decode_one_time_key_message(reader, one_time_key_message, message_length); - if (!check_message_fields(reader)) { + if (!check_message_fields(reader, their_identity_key)) { last_error = olm::ErrorCode::BAD_MESSAGE_FORMAT; return std::size_t(-1); } + if (reader.identity_key && their_identity_key) { + bool same = 0 == std::memcmp( + their_identity_key->public_key, reader.identity_key, KEY_LENGTH + ); + if (!same) { + last_error = olm::ErrorCode::BAD_MESSAGE_KEY_ID; + return std::size_t(-1); + } + } + olm::MessageReader message_reader; decode_message( message_reader, reader.message, reader.message_length, @@ -177,19 +190,28 @@ std::size_t olm::Session::new_inbound_session( bool olm::Session::matches_inbound_session( + olm::Curve25519PublicKey const * their_identity_key, std::uint8_t const * one_time_key_message, std::size_t message_length ) { olm::PreKeyMessageReader reader; decode_one_time_key_message(reader, one_time_key_message, message_length); - if (!check_message_fields(reader)) { + if (!check_message_fields(reader, their_identity_key)) { return false; } bool same = true; - same = same && 0 == std::memcmp( - reader.identity_key, alice_identity_key.public_key, KEY_LENGTH - ); + if (reader.identity_key) { + same = same && 0 == std::memcmp( + reader.identity_key, alice_identity_key.public_key, KEY_LENGTH + ); + } + if (their_identity_key) { + same = same && 0 == std::memcmp( + their_identity_key->public_key, alice_identity_key.public_key, + KEY_LENGTH + ); + } same = same && 0 == std::memcmp( reader.base_key, alice_base_key.public_key, KEY_LENGTH );