Merge pull request #31 from matrix-org/markjh/groupmessageindex

Return the message index when decrypting group messages.
This commit is contained in:
Mark Haines 2016-10-21 09:57:42 +01:00 committed by GitHub
commit 5a98012c0d
7 changed files with 35 additions and 14 deletions

View file

@ -140,7 +140,8 @@ size_t olm_group_decrypt(
uint8_t * message, size_t message_length,
/* output */
uint8_t * plaintext, size_t max_plaintext_length
uint8_t * plaintext, size_t max_plaintext_length,
uint32_t * message_index
);

View file

@ -403,8 +403,8 @@ DemoUser.prototype.decryptGroup = function(jsonpacket, callback) {
throw new Error("Unknown session id " + session_id);
}
var plaintext = session.decrypt(packet.body);
done(plaintext);
var result = session.decrypt(packet.body);
done(result.plaintext);
}, callback);
};

View file

@ -73,10 +73,12 @@ InboundGroupSession.prototype['decrypt'] = restore_stack(function(
// So we copy the array to a new buffer
var message_buffer = stack(message_array);
var plaintext_buffer = stack(max_plaintext_length + NULL_BYTE_PADDING_LENGTH);
var message_index = stack(4);
var plaintext_length = inbound_group_session_method(Module["_olm_group_decrypt"])(
this.ptr,
message_buffer, message_array.length,
plaintext_buffer, max_plaintext_length
plaintext_buffer, max_plaintext_length,
message_index
);
// Pointer_stringify requires a null-terminated argument (the optional
@ -86,7 +88,10 @@ InboundGroupSession.prototype['decrypt'] = restore_stack(function(
0, "i8"
);
return Pointer_stringify(plaintext_buffer);
return {
"plaintext": Pointer_stringify(plaintext_buffer),
"message_index": Module['getValue'](message_index, "i32")
}
});
InboundGroupSession.prototype['session_id'] = restore_stack(function() {

View file

@ -328,7 +328,7 @@ def do_group_decrypt(args):
session = InboundGroupSession()
session.unpickle(args.key, read_base64_file(args.session_file))
message = args.message_file.read()
plaintext = session.decrypt(message)
plaintext, message_index = session.decrypt(message)
with open(args.session_file, "wb") as f:
f.write(session.pickle(args.key))
args.plaintext_file.write(plaintext)

View file

@ -43,6 +43,7 @@ inbound_group_session_function(
lib.olm_group_decrypt,
c_void_p, c_size_t, # message
c_void_p, c_size_t, # plaintext
POINTER(c_uint32), # message_index
)
inbound_group_session_function(lib.olm_inbound_group_session_id_length)
@ -82,11 +83,14 @@ class InboundGroupSession(object):
)
plaintext_buffer = create_string_buffer(max_plaintext_length)
message_buffer = create_string_buffer(message)
message_index = c_uint32()
plaintext_length = lib.olm_group_decrypt(
self.ptr, message_buffer, len(message),
plaintext_buffer, max_plaintext_length
plaintext_buffer, max_plaintext_length,
byref(message_index)
)
return plaintext_buffer.raw[:plaintext_length]
return plaintext_buffer.raw[:plaintext_length], message_index
def session_id(self):
id_length = lib.olm_inbound_group_session_id_length(self.ptr)

View file

@ -263,7 +263,8 @@ size_t olm_group_decrypt_max_plaintext_length(
static size_t _decrypt(
OlmInboundGroupSession *session,
uint8_t * message, size_t message_length,
uint8_t * plaintext, size_t max_plaintext_length
uint8_t * plaintext, size_t max_plaintext_length,
uint32_t * message_index
) {
struct _OlmDecodeGroupMessageResults decoded_results;
size_t max_length, r;
@ -286,6 +287,10 @@ static size_t _decrypt(
return (size_t)-1;
}
if (message_index != NULL) {
*message_index = decoded_results.message_index;
}
/* verify the signature. We could do this before decoding the message, but
* we allow for the possibility of future protocol versions which use a
* different signing mechanism; we would rather throw "BAD_MESSAGE_VERSION"
@ -349,7 +354,8 @@ static size_t _decrypt(
size_t olm_group_decrypt(
OlmInboundGroupSession *session,
uint8_t * message, size_t message_length,
uint8_t * plaintext, size_t max_plaintext_length
uint8_t * plaintext, size_t max_plaintext_length,
uint32_t * message_index
) {
size_t raw_message_length;
@ -361,7 +367,8 @@ size_t olm_group_decrypt(
return _decrypt(
session, message, raw_message_length,
plaintext, max_plaintext_length
plaintext, max_plaintext_length,
message_index
);
}

View file

@ -161,10 +161,12 @@ int main() {
memcpy(msgcopy, msg, msglen);
size = olm_group_decrypt_max_plaintext_length(inbound_session, msgcopy, msglen);
uint8_t plaintext_buf[size];
uint32_t message_index;
res = olm_group_decrypt(inbound_session, msg, msglen,
plaintext_buf, size);
plaintext_buf, size, &message_index);
assert_equals(plaintext_length, res);
assert_equals(plaintext, plaintext_buf, res);
assert_equals(message_index, uint32_t(0));
}
{
@ -208,9 +210,11 @@ int main() {
memcpy(msgcopy, message, msglen);
uint8_t plaintext_buf[size];
uint32_t message_index;
res = olm_group_decrypt(
inbound_session, msgcopy, msglen, plaintext_buf, size
inbound_session, msgcopy, msglen, plaintext_buf, size, &message_index
);
assert_equals(message_index, uint32_t(0));
assert_equals(plaintext_length, res);
assert_equals(plaintext, plaintext_buf, res);
@ -227,7 +231,7 @@ int main() {
memcpy(msgcopy, message, msglen);
res = olm_group_decrypt(
inbound_session, msgcopy, msglen,
plaintext_buf, size
plaintext_buf, size, &message_index
);
assert_equals((size_t)-1, res);
assert_equals(