Remove session_id from group messages
Putting the session_id inside the packed message body makes it hard to extract so that we can decide which session to use. We don't think there is any advantage to having thes sesion_id protected by the HMACs, so we're going to move it to the JSON framing.
This commit is contained in:
parent
ee8172d882
commit
708fddd747
5 changed files with 10 additions and 45 deletions
|
@ -35,7 +35,6 @@ extern "C" {
|
||||||
* The length of the buffer needed to hold a group message.
|
* The length of the buffer needed to hold a group message.
|
||||||
*/
|
*/
|
||||||
size_t _olm_encode_group_message_length(
|
size_t _olm_encode_group_message_length(
|
||||||
size_t group_session_id_length,
|
|
||||||
uint32_t chain_index,
|
uint32_t chain_index,
|
||||||
size_t ciphertext_length,
|
size_t ciphertext_length,
|
||||||
size_t mac_length
|
size_t mac_length
|
||||||
|
@ -45,8 +44,6 @@ size_t _olm_encode_group_message_length(
|
||||||
* Writes the message headers into the output buffer.
|
* Writes the message headers into the output buffer.
|
||||||
*
|
*
|
||||||
* version: version number of the olm protocol
|
* version: version number of the olm protocol
|
||||||
* session_id: group session identifier
|
|
||||||
* session_id_length: length of session_id
|
|
||||||
* message_index: message index
|
* message_index: message index
|
||||||
* ciphertext_length: length of the ciphertext
|
* ciphertext_length: length of the ciphertext
|
||||||
* output: where to write the output. Should be at least
|
* output: where to write the output. Should be at least
|
||||||
|
@ -58,8 +55,6 @@ size_t _olm_encode_group_message_length(
|
||||||
*/
|
*/
|
||||||
size_t _olm_encode_group_message(
|
size_t _olm_encode_group_message(
|
||||||
uint8_t version,
|
uint8_t version,
|
||||||
const uint8_t *session_id,
|
|
||||||
size_t session_id_length,
|
|
||||||
uint32_t message_index,
|
uint32_t message_index,
|
||||||
size_t ciphertext_length,
|
size_t ciphertext_length,
|
||||||
uint8_t *output,
|
uint8_t *output,
|
||||||
|
@ -69,8 +64,6 @@ size_t _olm_encode_group_message(
|
||||||
|
|
||||||
struct _OlmDecodeGroupMessageResults {
|
struct _OlmDecodeGroupMessageResults {
|
||||||
uint8_t version;
|
uint8_t version;
|
||||||
const uint8_t *session_id;
|
|
||||||
size_t session_id_length;
|
|
||||||
uint32_t message_index;
|
uint32_t message_index;
|
||||||
int has_message_index;
|
int has_message_index;
|
||||||
const uint8_t *ciphertext;
|
const uint8_t *ciphertext;
|
||||||
|
|
|
@ -231,9 +231,7 @@ static size_t _decrypt(
|
||||||
return (size_t)-1;
|
return (size_t)-1;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!decoded_results.has_message_index || !decoded_results.session_id
|
if (!decoded_results.has_message_index || !decoded_results.ciphertext ) {
|
||||||
|| !decoded_results.ciphertext
|
|
||||||
) {
|
|
||||||
session->last_error = OLM_BAD_MESSAGE_FORMAT;
|
session->last_error = OLM_BAD_MESSAGE_FORMAT;
|
||||||
return (size_t)-1;
|
return (size_t)-1;
|
||||||
}
|
}
|
||||||
|
|
|
@ -328,18 +328,15 @@ void olm::decode_one_time_key_message(
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
static const std::uint8_t GROUP_SESSION_ID_TAG = 012;
|
static const std::uint8_t GROUP_MESSAGE_INDEX_TAG = 010;
|
||||||
static const std::uint8_t GROUP_MESSAGE_INDEX_TAG = 020;
|
static const std::uint8_t GROUP_CIPHERTEXT_TAG = 022;
|
||||||
static const std::uint8_t GROUP_CIPHERTEXT_TAG = 032;
|
|
||||||
|
|
||||||
size_t _olm_encode_group_message_length(
|
size_t _olm_encode_group_message_length(
|
||||||
size_t group_session_id_length,
|
|
||||||
uint32_t message_index,
|
uint32_t message_index,
|
||||||
size_t ciphertext_length,
|
size_t ciphertext_length,
|
||||||
size_t mac_length
|
size_t mac_length
|
||||||
) {
|
) {
|
||||||
size_t length = VERSION_LENGTH;
|
size_t length = VERSION_LENGTH;
|
||||||
length += 1 + varstring_length(group_session_id_length);
|
|
||||||
length += 1 + varint_length(message_index);
|
length += 1 + varint_length(message_index);
|
||||||
length += 1 + varstring_length(ciphertext_length);
|
length += 1 + varstring_length(ciphertext_length);
|
||||||
length += mac_length;
|
length += mac_length;
|
||||||
|
@ -349,19 +346,14 @@ size_t _olm_encode_group_message_length(
|
||||||
|
|
||||||
size_t _olm_encode_group_message(
|
size_t _olm_encode_group_message(
|
||||||
uint8_t version,
|
uint8_t version,
|
||||||
const uint8_t *session_id,
|
|
||||||
size_t session_id_length,
|
|
||||||
uint32_t message_index,
|
uint32_t message_index,
|
||||||
size_t ciphertext_length,
|
size_t ciphertext_length,
|
||||||
uint8_t *output,
|
uint8_t *output,
|
||||||
uint8_t **ciphertext_ptr
|
uint8_t **ciphertext_ptr
|
||||||
) {
|
) {
|
||||||
std::uint8_t * pos = output;
|
std::uint8_t * pos = output;
|
||||||
std::uint8_t * session_id_pos;
|
|
||||||
|
|
||||||
*(pos++) = version;
|
*(pos++) = version;
|
||||||
pos = encode(pos, GROUP_SESSION_ID_TAG, session_id_pos, session_id_length);
|
|
||||||
std::memcpy(session_id_pos, session_id, session_id_length);
|
|
||||||
pos = encode(pos, GROUP_MESSAGE_INDEX_TAG, message_index);
|
pos = encode(pos, GROUP_MESSAGE_INDEX_TAG, message_index);
|
||||||
pos = encode(pos, GROUP_CIPHERTEXT_TAG, *ciphertext_ptr, ciphertext_length);
|
pos = encode(pos, GROUP_CIPHERTEXT_TAG, *ciphertext_ptr, ciphertext_length);
|
||||||
return pos-output;
|
return pos-output;
|
||||||
|
@ -376,8 +368,6 @@ void _olm_decode_group_message(
|
||||||
std::uint8_t const * end = input + input_length - mac_length;
|
std::uint8_t const * end = input + input_length - mac_length;
|
||||||
std::uint8_t const * unknown = nullptr;
|
std::uint8_t const * unknown = nullptr;
|
||||||
|
|
||||||
results->session_id = nullptr;
|
|
||||||
results->session_id_length = 0;
|
|
||||||
bool has_message_index = false;
|
bool has_message_index = false;
|
||||||
results->message_index = 0;
|
results->message_index = 0;
|
||||||
results->ciphertext = nullptr;
|
results->ciphertext = nullptr;
|
||||||
|
@ -388,10 +378,6 @@ void _olm_decode_group_message(
|
||||||
results->version = *(pos++);
|
results->version = *(pos++);
|
||||||
|
|
||||||
while (pos != end) {
|
while (pos != end) {
|
||||||
pos = decode(
|
|
||||||
pos, end, GROUP_SESSION_ID_TAG,
|
|
||||||
results->session_id, results->session_id_length
|
|
||||||
);
|
|
||||||
pos = decode(
|
pos = decode(
|
||||||
pos, end, GROUP_MESSAGE_INDEX_TAG,
|
pos, end, GROUP_MESSAGE_INDEX_TAG,
|
||||||
results->message_index, has_message_index
|
results->message_index, has_message_index
|
||||||
|
|
|
@ -187,7 +187,7 @@ static size_t raw_message_length(
|
||||||
mac_length = megolm_cipher->ops->mac_length(megolm_cipher);
|
mac_length = megolm_cipher->ops->mac_length(megolm_cipher);
|
||||||
|
|
||||||
return _olm_encode_group_message_length(
|
return _olm_encode_group_message_length(
|
||||||
GROUP_SESSION_ID_LENGTH, session->ratchet.counter,
|
session->ratchet.counter,
|
||||||
ciphertext_length, mac_length);
|
ciphertext_length, mac_length);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -220,7 +220,6 @@ static size_t _encrypt(
|
||||||
*/
|
*/
|
||||||
message_length = _olm_encode_group_message(
|
message_length = _olm_encode_group_message(
|
||||||
OLM_PROTOCOL_VERSION,
|
OLM_PROTOCOL_VERSION,
|
||||||
session->session_id, GROUP_SESSION_ID_LENGTH,
|
|
||||||
session->ratchet.counter,
|
session->ratchet.counter,
|
||||||
ciphertext_length,
|
ciphertext_length,
|
||||||
buffer,
|
buffer,
|
||||||
|
|
|
@ -67,12 +67,8 @@ assert_equals(message2, output, 35);
|
||||||
|
|
||||||
TestCase test_case("Group message encode test");
|
TestCase test_case("Group message encode test");
|
||||||
|
|
||||||
const uint8_t session_id[] = "sessionid";
|
size_t length = _olm_encode_group_message_length(200, 10, 8);
|
||||||
size_t session_id_len = 9;
|
size_t expected_length = 1 + (1+2) + (2+10) + 8;
|
||||||
|
|
||||||
size_t length = _olm_encode_group_message_length(
|
|
||||||
session_id_len, 200, 10, 8);
|
|
||||||
size_t expected_length = 1 + (2+session_id_len) + (1+2) + (2+10) + 8;
|
|
||||||
assert_equals(expected_length, length);
|
assert_equals(expected_length, length);
|
||||||
|
|
||||||
uint8_t output[50];
|
uint8_t output[50];
|
||||||
|
@ -80,7 +76,6 @@ assert_equals(message2, output, 35);
|
||||||
|
|
||||||
_olm_encode_group_message(
|
_olm_encode_group_message(
|
||||||
3,
|
3,
|
||||||
session_id, session_id_len,
|
|
||||||
200, // counter
|
200, // counter
|
||||||
10, // ciphertext length
|
10, // ciphertext length
|
||||||
output,
|
output,
|
||||||
|
@ -89,9 +84,8 @@ assert_equals(message2, output, 35);
|
||||||
|
|
||||||
uint8_t expected[] =
|
uint8_t expected[] =
|
||||||
"\x03"
|
"\x03"
|
||||||
"\x0A\x09sessionid"
|
"\x08\xC8\x01"
|
||||||
"\x10\xC8\x01"
|
"\x12\x0A";
|
||||||
"\x1A\x0A";
|
|
||||||
|
|
||||||
assert_equals(expected, output, sizeof(expected)-1);
|
assert_equals(expected, output, sizeof(expected)-1);
|
||||||
assert_equals(output+sizeof(expected)-1, ciphertext_ptr);
|
assert_equals(output+sizeof(expected)-1, ciphertext_ptr);
|
||||||
|
@ -103,17 +97,12 @@ assert_equals(message2, output, 35);
|
||||||
struct _OlmDecodeGroupMessageResults results;
|
struct _OlmDecodeGroupMessageResults results;
|
||||||
std::uint8_t message[] =
|
std::uint8_t message[] =
|
||||||
"\x03"
|
"\x03"
|
||||||
"\x0A\x09sessionid"
|
"\x08\xC8\x01"
|
||||||
"\x10\xC8\x01"
|
"\x12\x0A" "ciphertext"
|
||||||
"\x1A\x0A" "ciphertext"
|
|
||||||
"hmacsha2";
|
"hmacsha2";
|
||||||
|
|
||||||
const uint8_t expected_session_id[] = "sessionid";
|
|
||||||
|
|
||||||
_olm_decode_group_message(message, sizeof(message)-1, 8, &results);
|
_olm_decode_group_message(message, sizeof(message)-1, 8, &results);
|
||||||
assert_equals(std::uint8_t(3), results.version);
|
assert_equals(std::uint8_t(3), results.version);
|
||||||
assert_equals(std::size_t(9), results.session_id_length);
|
|
||||||
assert_equals(expected_session_id, results.session_id, 9);
|
|
||||||
assert_equals(1, results.has_message_index);
|
assert_equals(1, results.has_message_index);
|
||||||
assert_equals(std::uint32_t(200), results.message_index);
|
assert_equals(std::uint32_t(200), results.message_index);
|
||||||
assert_equals(std::size_t(10), results.ciphertext_length);
|
assert_equals(std::size_t(10), results.ciphertext_length);
|
||||||
|
|
Loading…
Reference in a new issue