diff --git a/src/inbound_group_session.c b/src/inbound_group_session.c index 4796414..34908a9 100644 --- a/src/inbound_group_session.c +++ b/src/inbound_group_session.c @@ -22,8 +22,12 @@ #include "olm/error.h" #include "olm/megolm.h" #include "olm/message.h" +#include "olm/pickle.h" +#include "olm/pickle_encoding.h" + #define OLM_PROTOCOL_VERSION 3 +#define PICKLE_VERSION 1 struct OlmInboundGroupSession { /** our earliest known ratchet value */ @@ -86,6 +90,78 @@ size_t olm_init_inbound_group_session( return 0; } +static size_t raw_pickle_length( + const OlmInboundGroupSession *session +) { + size_t length = 0; + length += _olm_pickle_uint32_length(PICKLE_VERSION); + length += megolm_pickle_length(&session->initial_ratchet); + length += megolm_pickle_length(&session->latest_ratchet); + return length; +} + +size_t olm_pickle_inbound_group_session_length( + const OlmInboundGroupSession *session +) { + return _olm_enc_output_length(raw_pickle_length(session)); +} + +size_t olm_pickle_inbound_group_session( + OlmInboundGroupSession *session, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +) { + size_t raw_length = raw_pickle_length(session); + uint8_t *pos; + + if (pickled_length < _olm_enc_output_length(raw_length)) { + session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; + return (size_t)-1; + } + + pos = _olm_enc_output_pos(pickled, raw_length); + pos = _olm_pickle_uint32(pos, PICKLE_VERSION); + pos = megolm_pickle(&session->initial_ratchet, pos); + pos = megolm_pickle(&session->latest_ratchet, pos); + + return _olm_enc_output(key, key_length, pickled, raw_length); +} + +size_t olm_unpickle_inbound_group_session( + OlmInboundGroupSession *session, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +) { + const uint8_t *pos; + const uint8_t *end; + uint32_t pickle_version; + + size_t raw_length = _olm_enc_input( + key, key_length, pickled, pickled_length, &(session->last_error) + ); + if (raw_length == (size_t)-1) { + return raw_length; + } + + pos = pickled; + end = pos + raw_length; + pos = _olm_unpickle_uint32(pos, end, &pickle_version); + if (pickle_version != PICKLE_VERSION) { + session->last_error = OLM_UNKNOWN_PICKLE_VERSION; + return (size_t)-1; + } + pos = megolm_unpickle(&session->initial_ratchet, pos, end); + pos = megolm_unpickle(&session->latest_ratchet, pos, end); + + if (end != pos) { + /* We had the wrong number of bytes in the input. */ + session->last_error = OLM_CORRUPTED_PICKLE; + return (size_t)-1; + } + + return pickled_length; +} + size_t olm_group_decrypt_max_plaintext_length( OlmInboundGroupSession *session, uint8_t * message, size_t message_length diff --git a/tests/test_group_session.cpp b/tests/test_group_session.cpp index 5bbdc9d..4a82154 100644 --- a/tests/test_group_session.cpp +++ b/tests/test_group_session.cpp @@ -20,7 +20,7 @@ int main() { { - TestCase test_case("Pickle outbound group"); + TestCase test_case("Pickle outbound group session"); size_t size = olm_outbound_group_session_size(); uint8_t memory[size]; @@ -50,6 +50,37 @@ int main() { } +{ + TestCase test_case("Pickle inbound group session"); + + size_t size = olm_inbound_group_session_size(); + uint8_t memory[size]; + OlmInboundGroupSession *session = olm_inbound_group_session(memory); + + size_t pickle_length = olm_pickle_inbound_group_session_length(session); + uint8_t pickle1[pickle_length]; + olm_pickle_inbound_group_session(session, + "secret_key", 10, + pickle1, pickle_length); + uint8_t pickle2[pickle_length]; + memcpy(pickle2, pickle1, pickle_length); + + uint8_t buffer2[size]; + OlmInboundGroupSession *session2 = olm_inbound_group_session(buffer2); + size_t res = olm_unpickle_inbound_group_session(session2, + "secret_key", 10, + pickle2, pickle_length); + assert_not_equals((size_t)-1, res); + assert_equals(pickle_length, + olm_pickle_inbound_group_session_length(session2)); + olm_pickle_inbound_group_session(session2, + "secret_key", 10, + pickle2, pickle_length); + + assert_equals(pickle1, pickle2, pickle_length); +} + + { TestCase test_case("Group message send/receive");