Merge pull request #31 from matrix-org/markjh/groupmessageindex

Return the message index when decrypting group messages.
This commit is contained in:
Mark Haines 2016-10-21 09:57:42 +01:00 committed by GitHub
commit 5a98012c0d
7 changed files with 35 additions and 14 deletions

View file

@ -140,7 +140,8 @@ size_t olm_group_decrypt(
uint8_t * message, size_t message_length, uint8_t * message, size_t message_length,
/* output */ /* output */
uint8_t * plaintext, size_t max_plaintext_length uint8_t * plaintext, size_t max_plaintext_length,
uint32_t * message_index
); );

View file

@ -403,8 +403,8 @@ DemoUser.prototype.decryptGroup = function(jsonpacket, callback) {
throw new Error("Unknown session id " + session_id); throw new Error("Unknown session id " + session_id);
} }
var plaintext = session.decrypt(packet.body); var result = session.decrypt(packet.body);
done(plaintext); done(result.plaintext);
}, callback); }, callback);
}; };

View file

@ -73,10 +73,12 @@ InboundGroupSession.prototype['decrypt'] = restore_stack(function(
// So we copy the array to a new buffer // So we copy the array to a new buffer
var message_buffer = stack(message_array); var message_buffer = stack(message_array);
var plaintext_buffer = stack(max_plaintext_length + NULL_BYTE_PADDING_LENGTH); var plaintext_buffer = stack(max_plaintext_length + NULL_BYTE_PADDING_LENGTH);
var message_index = stack(4);
var plaintext_length = inbound_group_session_method(Module["_olm_group_decrypt"])( var plaintext_length = inbound_group_session_method(Module["_olm_group_decrypt"])(
this.ptr, this.ptr,
message_buffer, message_array.length, message_buffer, message_array.length,
plaintext_buffer, max_plaintext_length plaintext_buffer, max_plaintext_length,
message_index
); );
// Pointer_stringify requires a null-terminated argument (the optional // Pointer_stringify requires a null-terminated argument (the optional
@ -86,7 +88,10 @@ InboundGroupSession.prototype['decrypt'] = restore_stack(function(
0, "i8" 0, "i8"
); );
return Pointer_stringify(plaintext_buffer); return {
"plaintext": Pointer_stringify(plaintext_buffer),
"message_index": Module['getValue'](message_index, "i32")
}
}); });
InboundGroupSession.prototype['session_id'] = restore_stack(function() { InboundGroupSession.prototype['session_id'] = restore_stack(function() {

View file

@ -328,7 +328,7 @@ def do_group_decrypt(args):
session = InboundGroupSession() session = InboundGroupSession()
session.unpickle(args.key, read_base64_file(args.session_file)) session.unpickle(args.key, read_base64_file(args.session_file))
message = args.message_file.read() message = args.message_file.read()
plaintext = session.decrypt(message) plaintext, message_index = session.decrypt(message)
with open(args.session_file, "wb") as f: with open(args.session_file, "wb") as f:
f.write(session.pickle(args.key)) f.write(session.pickle(args.key))
args.plaintext_file.write(plaintext) args.plaintext_file.write(plaintext)

View file

@ -43,6 +43,7 @@ inbound_group_session_function(
lib.olm_group_decrypt, lib.olm_group_decrypt,
c_void_p, c_size_t, # message c_void_p, c_size_t, # message
c_void_p, c_size_t, # plaintext c_void_p, c_size_t, # plaintext
POINTER(c_uint32), # message_index
) )
inbound_group_session_function(lib.olm_inbound_group_session_id_length) inbound_group_session_function(lib.olm_inbound_group_session_id_length)
@ -82,11 +83,14 @@ class InboundGroupSession(object):
) )
plaintext_buffer = create_string_buffer(max_plaintext_length) plaintext_buffer = create_string_buffer(max_plaintext_length)
message_buffer = create_string_buffer(message) message_buffer = create_string_buffer(message)
message_index = c_uint32()
plaintext_length = lib.olm_group_decrypt( plaintext_length = lib.olm_group_decrypt(
self.ptr, message_buffer, len(message), self.ptr, message_buffer, len(message),
plaintext_buffer, max_plaintext_length plaintext_buffer, max_plaintext_length,
byref(message_index)
) )
return plaintext_buffer.raw[:plaintext_length] return plaintext_buffer.raw[:plaintext_length], message_index
def session_id(self): def session_id(self):
id_length = lib.olm_inbound_group_session_id_length(self.ptr) id_length = lib.olm_inbound_group_session_id_length(self.ptr)

View file

@ -263,7 +263,8 @@ size_t olm_group_decrypt_max_plaintext_length(
static size_t _decrypt( static size_t _decrypt(
OlmInboundGroupSession *session, OlmInboundGroupSession *session,
uint8_t * message, size_t message_length, uint8_t * message, size_t message_length,
uint8_t * plaintext, size_t max_plaintext_length uint8_t * plaintext, size_t max_plaintext_length,
uint32_t * message_index
) { ) {
struct _OlmDecodeGroupMessageResults decoded_results; struct _OlmDecodeGroupMessageResults decoded_results;
size_t max_length, r; size_t max_length, r;
@ -286,6 +287,10 @@ static size_t _decrypt(
return (size_t)-1; return (size_t)-1;
} }
if (message_index != NULL) {
*message_index = decoded_results.message_index;
}
/* verify the signature. We could do this before decoding the message, but /* verify the signature. We could do this before decoding the message, but
* we allow for the possibility of future protocol versions which use a * we allow for the possibility of future protocol versions which use a
* different signing mechanism; we would rather throw "BAD_MESSAGE_VERSION" * different signing mechanism; we would rather throw "BAD_MESSAGE_VERSION"
@ -349,7 +354,8 @@ static size_t _decrypt(
size_t olm_group_decrypt( size_t olm_group_decrypt(
OlmInboundGroupSession *session, OlmInboundGroupSession *session,
uint8_t * message, size_t message_length, uint8_t * message, size_t message_length,
uint8_t * plaintext, size_t max_plaintext_length uint8_t * plaintext, size_t max_plaintext_length,
uint32_t * message_index
) { ) {
size_t raw_message_length; size_t raw_message_length;
@ -361,7 +367,8 @@ size_t olm_group_decrypt(
return _decrypt( return _decrypt(
session, message, raw_message_length, session, message, raw_message_length,
plaintext, max_plaintext_length plaintext, max_plaintext_length,
message_index
); );
} }

View file

@ -161,10 +161,12 @@ int main() {
memcpy(msgcopy, msg, msglen); memcpy(msgcopy, msg, msglen);
size = olm_group_decrypt_max_plaintext_length(inbound_session, msgcopy, msglen); size = olm_group_decrypt_max_plaintext_length(inbound_session, msgcopy, msglen);
uint8_t plaintext_buf[size]; uint8_t plaintext_buf[size];
uint32_t message_index;
res = olm_group_decrypt(inbound_session, msg, msglen, res = olm_group_decrypt(inbound_session, msg, msglen,
plaintext_buf, size); plaintext_buf, size, &message_index);
assert_equals(plaintext_length, res); assert_equals(plaintext_length, res);
assert_equals(plaintext, plaintext_buf, res); assert_equals(plaintext, plaintext_buf, res);
assert_equals(message_index, uint32_t(0));
} }
{ {
@ -208,9 +210,11 @@ int main() {
memcpy(msgcopy, message, msglen); memcpy(msgcopy, message, msglen);
uint8_t plaintext_buf[size]; uint8_t plaintext_buf[size];
uint32_t message_index;
res = olm_group_decrypt( res = olm_group_decrypt(
inbound_session, msgcopy, msglen, plaintext_buf, size inbound_session, msgcopy, msglen, plaintext_buf, size, &message_index
); );
assert_equals(message_index, uint32_t(0));
assert_equals(plaintext_length, res); assert_equals(plaintext_length, res);
assert_equals(plaintext, plaintext_buf, res); assert_equals(plaintext, plaintext_buf, res);
@ -227,7 +231,7 @@ int main() {
memcpy(msgcopy, message, msglen); memcpy(msgcopy, message, msglen);
res = olm_group_decrypt( res = olm_group_decrypt(
inbound_session, msgcopy, msglen, inbound_session, msgcopy, msglen,
plaintext_buf, size plaintext_buf, size, &message_index
); );
assert_equals((size_t)-1, res); assert_equals((size_t)-1, res);
assert_equals( assert_equals(