diff --git a/include/olm/error.h b/include/olm/error.h index 3f74992..98d2cf5 100644 --- a/include/olm/error.h +++ b/include/olm/error.h @@ -32,8 +32,12 @@ enum OlmErrorCode { OLM_UNKNOWN_PICKLE_VERSION = 9, /*!< The pickled object is too new */ OLM_CORRUPTED_PICKLE = 10, /*!< The pickled object couldn't be decoded */ - OLM_BAD_RATCHET_KEY = 11, - OLM_BAD_CHAIN_INDEX = 12, + OLM_BAD_SESSION_KEY = 11, /*!< Attempt to initialise an inbound group + session from an invalid session key */ + OLM_UNKNOWN_MESSAGE_INDEX = 12, /*!< Attempt to decode a message whose + * index is earlier than our earliest + * known session key. + */ /* remember to update the list of string constants in error.c when updating * this list. */ diff --git a/include/olm/inbound_group_session.h b/include/olm/inbound_group_session.h index 4cf4ac4..e24f377 100644 --- a/include/olm/inbound_group_session.h +++ b/include/olm/inbound_group_session.h @@ -91,7 +91,7 @@ size_t olm_unpickle_inbound_group_session( * error code. The last_error will be: * * * OLM_INVALID_BASE64 if the session_key is not valid base64 - * * OLM_BAD_RATCHET_KEY if the session_key is invalid + * * OLM_BAD_SESSION_KEY if the session_key is invalid */ size_t olm_init_inbound_group_session( OlmInboundGroupSession *session, @@ -129,9 +129,9 @@ size_t olm_group_decrypt_max_plaintext_length( * * OLM_BAD_MESSAGE_VERSION if the message was encrypted with an unsupported * version of the protocol * * OLM_BAD_MESSAGE_FORMAT if the message headers could not be decoded - * * OLM_BAD_MESSAGE_MAC if the message could not be verified - * * OLM_BAD_CHAIN_INDEX if we do not have a ratchet key corresponding to the - * message's index (ie, it was sent before the ratchet key was shared with + * * OLM_BAD_MESSAGE_MAC if the message could not be verified + * * OLM_UNKNOWN_MESSAGE_INDEX if we do not have a session key corresponding to the + * message's index (ie, it was sent before the session key was shared with * us) */ size_t olm_group_decrypt( diff --git a/include/olm/message.h b/include/olm/message.h index bd7aec3..cff15f3 100644 --- a/include/olm/message.h +++ b/include/olm/message.h @@ -47,7 +47,7 @@ size_t _olm_encode_group_message_length( * version: version number of the olm protocol * session_id: group session identifier * session_id_length: length of session_id - * chain_index: message index + * message_index: message index * ciphertext_length: length of the ciphertext * output: where to write the output. Should be at least * olm_encode_group_message_length() bytes long. @@ -58,7 +58,7 @@ void _olm_encode_group_message( uint8_t version, const uint8_t *session_id, size_t session_id_length, - uint32_t chain_index, + uint32_t message_index, size_t ciphertext_length, uint8_t *output, uint8_t **ciphertext_ptr @@ -69,8 +69,8 @@ struct _OlmDecodeGroupMessageResults { uint8_t version; const uint8_t *session_id; size_t session_id_length; - uint32_t chain_index; - int has_chain_index; + uint32_t message_index; + int has_message_index; const uint8_t *ciphertext; size_t ciphertext_length; }; diff --git a/src/error.c b/src/error.c index 0690856..bd8a39d 100644 --- a/src/error.c +++ b/src/error.c @@ -27,6 +27,8 @@ static const char * ERRORS[] = { "BAD_ACCOUNT_KEY", "UNKNOWN_PICKLE_VERSION", "CORRUPTED_PICKLE", + "BAD_SESSION_KEY", + "UNKNOWN_MESSAGE_INDEX", }; const char * _olm_error_to_string(enum OlmErrorCode error) diff --git a/src/inbound_group_session.c b/src/inbound_group_session.c index 34908a9..cc6ba5e 100644 --- a/src/inbound_group_session.c +++ b/src/inbound_group_session.c @@ -78,7 +78,7 @@ size_t olm_init_inbound_group_session( } if (raw_length != MEGOLM_RATCHET_LENGTH) { - session->last_error = OLM_BAD_RATCHET_KEY; + session->last_error = OLM_BAD_SESSION_KEY; return (size_t)-1; } @@ -223,7 +223,7 @@ size_t olm_group_decrypt( return (size_t)-1; } - if (!decoded_results.has_chain_index || !decoded_results.session_id + if (!decoded_results.has_message_index || !decoded_results.session_id || !decoded_results.ciphertext ) { session->last_error = OLM_BAD_MESSAGE_FORMAT; @@ -241,11 +241,11 @@ size_t olm_group_decrypt( /* pick a megolm instance to use. If we're at or beyond the latest ratchet * value, use that */ - if ((int32_t)(decoded_results.chain_index - session->latest_ratchet.counter) >= 0) { + if ((int32_t)(decoded_results.message_index - session->latest_ratchet.counter) >= 0) { megolm = &session->latest_ratchet; - } else if ((int32_t)(decoded_results.chain_index - session->initial_ratchet.counter) < 0) { + } else if ((int32_t)(decoded_results.message_index - session->initial_ratchet.counter) < 0) { /* the counter is before our intial ratchet - we can't decode this. */ - session->last_error = OLM_BAD_CHAIN_INDEX; + session->last_error = OLM_UNKNOWN_MESSAGE_INDEX; return (size_t)-1; } else { /* otherwise, start from the initial megolm. Take a copy so that we @@ -254,7 +254,7 @@ size_t olm_group_decrypt( megolm = &tmp_megolm; } - megolm_advance_to(megolm, decoded_results.chain_index); + megolm_advance_to(megolm, decoded_results.message_index); /* now try checking the mac, and decrypting */ r = cipher->ops->decrypt( diff --git a/src/message.cpp b/src/message.cpp index ec44262..ab4300e 100644 --- a/src/message.cpp +++ b/src/message.cpp @@ -328,17 +328,19 @@ void olm::decode_one_time_key_message( -static std::uint8_t const GROUP_SESSION_ID_TAG = 052; +static const std::uint8_t GROUP_SESSION_ID_TAG = 012; +static const std::uint8_t GROUP_MESSAGE_INDEX_TAG = 020; +static const std::uint8_t GROUP_CIPHERTEXT_TAG = 032; size_t _olm_encode_group_message_length( size_t group_session_id_length, - uint32_t chain_index, + uint32_t message_index, size_t ciphertext_length, size_t mac_length ) { size_t length = VERSION_LENGTH; length += 1 + varstring_length(group_session_id_length); - length += 1 + varint_length(chain_index); + length += 1 + varint_length(message_index); length += 1 + varstring_length(ciphertext_length); length += mac_length; return length; @@ -349,7 +351,7 @@ void _olm_encode_group_message( uint8_t version, const uint8_t *session_id, size_t session_id_length, - uint32_t chain_index, + uint32_t message_index, size_t ciphertext_length, uint8_t *output, uint8_t **ciphertext_ptr @@ -360,8 +362,8 @@ void _olm_encode_group_message( *(pos++) = version; pos = encode(pos, GROUP_SESSION_ID_TAG, session_id_pos, session_id_length); std::memcpy(session_id_pos, session_id, session_id_length); - pos = encode(pos, COUNTER_TAG, chain_index); - pos = encode(pos, CIPHERTEXT_TAG, *ciphertext_ptr, ciphertext_length); + pos = encode(pos, GROUP_MESSAGE_INDEX_TAG, message_index); + pos = encode(pos, GROUP_CIPHERTEXT_TAG, *ciphertext_ptr, ciphertext_length); } void _olm_decode_group_message( @@ -375,8 +377,8 @@ void _olm_decode_group_message( results->session_id = nullptr; results->session_id_length = 0; - bool has_chain_index = false; - results->chain_index = 0; + bool has_message_index = false; + results->message_index = 0; results->ciphertext = nullptr; results->ciphertext_length = 0; @@ -390,11 +392,11 @@ void _olm_decode_group_message( results->session_id, results->session_id_length ); pos = decode( - pos, end, COUNTER_TAG, - results->chain_index, has_chain_index + pos, end, GROUP_MESSAGE_INDEX_TAG, + results->message_index, has_message_index ); pos = decode( - pos, end, CIPHERTEXT_TAG, + pos, end, GROUP_CIPHERTEXT_TAG, results->ciphertext, results->ciphertext_length ); if (unknown == pos) { @@ -403,5 +405,5 @@ void _olm_decode_group_message( unknown = pos; } - results->has_chain_index = (int)has_chain_index; + results->has_message_index = (int)has_message_index; } diff --git a/tests/test_message.cpp b/tests/test_message.cpp index 5fec9e0..30c10a0 100644 --- a/tests/test_message.cpp +++ b/tests/test_message.cpp @@ -89,9 +89,9 @@ assert_equals(message2, output, 35); uint8_t expected[] = "\x03" - "\x2A\x09sessionid" - "\x10\xc8\x01" - "\x22\x0a"; + "\x0A\x09sessionid" + "\x10\xC8\x01" + "\x1A\x0A"; assert_equals(expected, output, sizeof(expected)-1); assert_equals(output+sizeof(expected)-1, ciphertext_ptr); @@ -103,9 +103,9 @@ assert_equals(message2, output, 35); struct _OlmDecodeGroupMessageResults results; std::uint8_t message[] = "\x03" - "\x2A\x09sessionid" - "\x10\xc8\x01" - "\x22\x0A" "ciphertext" + "\x0A\x09sessionid" + "\x10\xC8\x01" + "\x1A\x0A" "ciphertext" "hmacsha2"; const uint8_t expected_session_id[] = "sessionid"; @@ -114,8 +114,8 @@ assert_equals(message2, output, 35); assert_equals(std::uint8_t(3), results.version); assert_equals(std::size_t(9), results.session_id_length); assert_equals(expected_session_id, results.session_id, 9); - assert_equals(1, results.has_chain_index); - assert_equals(std::uint32_t(200), results.chain_index); + assert_equals(1, results.has_message_index); + assert_equals(std::uint32_t(200), results.message_index); assert_equals(std::size_t(10), results.ciphertext_length); assert_equals(ciphertext, results.ciphertext, 10); } /* group message decode test */