Implement decrypting inbound group messages

Includes creation of inbound sessions, etc
This commit is contained in:
Richard van der Hoff 2016-05-18 17:23:09 +01:00
parent 8b1514c0a6
commit 39ad75314b
8 changed files with 480 additions and 6 deletions

View file

@ -32,6 +32,9 @@ enum OlmErrorCode {
OLM_UNKNOWN_PICKLE_VERSION = 9, /*!< The pickled object is too new */ OLM_UNKNOWN_PICKLE_VERSION = 9, /*!< The pickled object is too new */
OLM_CORRUPTED_PICKLE = 10, /*!< The pickled object couldn't be decoded */ OLM_CORRUPTED_PICKLE = 10, /*!< The pickled object couldn't be decoded */
OLM_BAD_RATCHET_KEY = 11,
OLM_BAD_CHAIN_INDEX = 12,
/* remember to update the list of string constants in error.c when updating /* remember to update the list of string constants in error.c when updating
* this list. */ * this list. */
}; };

View file

@ -0,0 +1,153 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef OLM_INBOUND_GROUP_SESSION_H_
#define OLM_INBOUND_GROUP_SESSION_H_
#include <stddef.h>
#include <stdint.h>
#ifdef __cplusplus
extern "C" {
#endif
typedef struct OlmInboundGroupSession OlmInboundGroupSession;
/** get the size of an inbound group session, in bytes. */
size_t olm_inbound_group_session_size();
/**
* Initialise an inbound group session object using the supplied memory
* The supplied memory should be at least olm_inbound_group_session_size()
* bytes.
*/
OlmInboundGroupSession * olm_inbound_group_session(
void *memory
);
/**
* A null terminated string describing the most recent error to happen to a
* group session */
const char *olm_inbound_group_session_last_error(
const OlmInboundGroupSession *session
);
/** Clears the memory used to back this group session */
size_t olm_clear_inbound_group_session(
OlmInboundGroupSession *session
);
/** Returns the number of bytes needed to store an inbound group session */
size_t olm_pickle_inbound_group_session_length(
const OlmInboundGroupSession *session
);
/**
* Stores a group session as a base64 string. Encrypts the session using the
* supplied key. Returns the length of the session on success.
*
* Returns olm_error() on failure. If the pickle output buffer
* is smaller than olm_pickle_inbound_group_session_length() then
* olm_inbound_group_session_last_error() will be "OUTPUT_BUFFER_TOO_SMALL"
*/
size_t olm_pickle_inbound_group_session(
OlmInboundGroupSession *session,
void const * key, size_t key_length,
void * pickled, size_t pickled_length
);
/**
* Loads a group session from a pickled base64 string. Decrypts the session
* using the supplied key.
*
* Returns olm_error() on failure. If the key doesn't match the one used to
* encrypt the account then olm_inbound_group_session_last_error() will be
* "BAD_ACCOUNT_KEY". If the base64 couldn't be decoded then
* olm_inbound_group_session_last_error() will be "INVALID_BASE64". The input
* pickled buffer is destroyed
*/
size_t olm_unpickle_inbound_group_session(
OlmInboundGroupSession *session,
void const * key, size_t key_length,
void * pickled, size_t pickled_length
);
/**
* Start a new inbound group session, based on the parameters supplied.
*
* Returns olm_error() on failure. On failure last_error will be set with an
* error code. The last_error will be:
*
* * OLM_INVALID_BASE64 if the session_key is not valid base64
* * OLM_BAD_RATCHET_KEY if the session_key is invalid
*/
size_t olm_init_inbound_group_session(
OlmInboundGroupSession *session,
uint32_t message_index,
/* base64-encoded key */
uint8_t const * session_key, size_t session_key_length
);
/**
* Get an upper bound on the number of bytes of plain-text the decrypt method
* will write for a given input message length. The actual size could be
* different due to padding.
*
* The input message buffer is destroyed.
*
* Returns olm_error() on failure.
*/
size_t olm_group_decrypt_max_plaintext_length(
OlmInboundGroupSession *session,
uint8_t * message, size_t message_length
);
/**
* Decrypt a message.
*
* The input message buffer is destroyed.
*
* Returns the length of the decrypted plain-text, or olm_error() on failure.
*
* On failure last_error will be set with an error code. The last_error will
* be:
* * OLM_OUTPUT_BUFFER_TOO_SMALL if the plain-text buffer is too small
* * OLM_INVALID_BASE64 if the message is not valid base-64
* * OLM_BAD_MESSAGE_VERSION if the message was encrypted with an unsupported
* version of the protocol
* * OLM_BAD_MESSAGE_FORMAT if the message headers could not be decoded
* * OLM_BAD_MESSAGE_MAC if the message could not be verified
* * OLM_BAD_CHAIN_INDEX if we do not have a ratchet key corresponding to the
* message's index (ie, it was sent before the ratchet key was shared with
* us)
*/
size_t olm_group_decrypt(
OlmInboundGroupSession *session,
/* input; note that it will be overwritten with the base64-decoded
message. */
uint8_t * message, size_t message_length,
/* output */
uint8_t * plaintext, size_t max_plaintext_length
);
#ifdef __cplusplus
} // extern "C"
#endif
#endif /* OLM_INBOUND_GROUP_SESSION_H_ */

View file

@ -65,6 +65,30 @@ void _olm_encode_group_message(
); );
struct _OlmDecodeGroupMessageResults {
uint8_t version;
const uint8_t *session_id;
size_t session_id_length;
uint32_t chain_index;
int has_chain_index;
const uint8_t *ciphertext;
size_t ciphertext_length;
};
/**
* Reads the message headers from the input buffer.
*/
void _olm_decode_group_message(
const uint8_t *input, size_t input_length,
size_t mac_length,
/* output structure: updated with results */
struct _OlmDecodeGroupMessageResults *results
);
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif #endif

View file

@ -19,6 +19,7 @@
#include <stddef.h> #include <stddef.h>
#include <stdint.h> #include <stdint.h>
#include "olm/inbound_group_session.h"
#include "olm/outbound_group_session.h" #include "olm/outbound_group_session.h"
#ifdef __cplusplus #ifdef __cplusplus

199
src/inbound_group_session.c Normal file
View file

@ -0,0 +1,199 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "olm/inbound_group_session.h"
#include <string.h>
#include "olm/base64.h"
#include "olm/cipher.h"
#include "olm/error.h"
#include "olm/megolm.h"
#include "olm/message.h"
#define OLM_PROTOCOL_VERSION 3
struct OlmInboundGroupSession {
/** our earliest known ratchet value */
Megolm initial_ratchet;
/** The most recent ratchet value */
Megolm latest_ratchet;
enum OlmErrorCode last_error;
};
size_t olm_inbound_group_session_size() {
return sizeof(OlmInboundGroupSession);
}
OlmInboundGroupSession * olm_inbound_group_session(
void *memory
) {
OlmInboundGroupSession *session = memory;
olm_clear_inbound_group_session(session);
return session;
}
const char *olm_inbound_group_session_last_error(
const OlmInboundGroupSession *session
) {
return _olm_error_to_string(session->last_error);
}
size_t olm_clear_inbound_group_session(
OlmInboundGroupSession *session
) {
memset(session, 0, sizeof(OlmInboundGroupSession));
return sizeof(OlmInboundGroupSession);
}
size_t olm_init_inbound_group_session(
OlmInboundGroupSession *session,
uint32_t message_index,
const uint8_t * session_key, size_t session_key_length
) {
uint8_t key_buf[MEGOLM_RATCHET_LENGTH];
size_t raw_length = _olm_decode_base64_length(session_key_length);
if (raw_length == (size_t)-1) {
session->last_error = OLM_INVALID_BASE64;
return (size_t)-1;
}
if (raw_length != MEGOLM_RATCHET_LENGTH) {
session->last_error = OLM_BAD_RATCHET_KEY;
return (size_t)-1;
}
_olm_decode_base64(session_key, session_key_length, key_buf);
megolm_init(&session->initial_ratchet, key_buf, message_index);
megolm_init(&session->latest_ratchet, key_buf, message_index);
memset(key_buf, 0, MEGOLM_RATCHET_LENGTH);
return 0;
}
size_t olm_group_decrypt_max_plaintext_length(
OlmInboundGroupSession *session,
uint8_t * message, size_t message_length
) {
size_t r;
const struct _olm_cipher *cipher = megolm_cipher();
struct _OlmDecodeGroupMessageResults decoded_results;
r = _olm_decode_base64(message, message_length, message);
if (r == (size_t)-1) {
session->last_error = OLM_INVALID_BASE64;
return r;
}
_olm_decode_group_message(
message, message_length,
cipher->ops->mac_length(cipher),
&decoded_results);
if (decoded_results.version != OLM_PROTOCOL_VERSION) {
session->last_error = OLM_BAD_MESSAGE_VERSION;
return (size_t)-1;
}
if (!decoded_results.ciphertext) {
session->last_error = OLM_BAD_MESSAGE_FORMAT;
return (size_t)-1;
}
return cipher->ops->decrypt_max_plaintext_length(
cipher, decoded_results.ciphertext_length);
}
size_t olm_group_decrypt(
OlmInboundGroupSession *session,
uint8_t * message, size_t message_length,
uint8_t * plaintext, size_t max_plaintext_length
) {
struct _OlmDecodeGroupMessageResults decoded_results;
const struct _olm_cipher *cipher = megolm_cipher();
size_t max_length, raw_message_length, r;
Megolm *megolm;
Megolm tmp_megolm;
raw_message_length = _olm_decode_base64(message, message_length, message);
if (raw_message_length == (size_t)-1) {
session->last_error = OLM_INVALID_BASE64;
return (size_t)-1;
}
_olm_decode_group_message(
message, raw_message_length,
cipher->ops->mac_length(cipher),
&decoded_results);
if (decoded_results.version != OLM_PROTOCOL_VERSION) {
session->last_error = OLM_BAD_MESSAGE_VERSION;
return (size_t)-1;
}
if (!decoded_results.has_chain_index || !decoded_results.session_id
|| !decoded_results.ciphertext
) {
session->last_error = OLM_BAD_MESSAGE_FORMAT;
return (size_t)-1;
}
max_length = cipher->ops->decrypt_max_plaintext_length(
cipher,
decoded_results.ciphertext_length
);
if (max_plaintext_length < max_length) {
session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL;
return (size_t)-1;
}
/* pick a megolm instance to use. If we're at or beyond the latest ratchet
* value, use that */
if ((int32_t)(decoded_results.chain_index - session->latest_ratchet.counter) >= 0) {
megolm = &session->latest_ratchet;
} else if ((int32_t)(decoded_results.chain_index - session->initial_ratchet.counter) < 0) {
/* the counter is before our intial ratchet - we can't decode this. */
session->last_error = OLM_BAD_CHAIN_INDEX;
return (size_t)-1;
} else {
/* otherwise, start from the initial megolm. Take a copy so that we
* don't overwrite the initial megolm */
tmp_megolm = session->initial_ratchet;
megolm = &tmp_megolm;
}
megolm_advance_to(megolm, decoded_results.chain_index);
/* now try checking the mac, and decrypting */
r = cipher->ops->decrypt(
cipher,
megolm_get_data(megolm), MEGOLM_RATCHET_LENGTH,
message, raw_message_length,
decoded_results.ciphertext, decoded_results.ciphertext_length,
plaintext, max_plaintext_length
);
memset(&tmp_megolm, 0, sizeof(tmp_megolm));
if (r == (size_t)-1) {
session->last_error = OLM_BAD_MESSAGE_MAC;
return r;
}
return r;
}

View file

@ -363,3 +363,45 @@ void _olm_encode_group_message(
pos = encode(pos, COUNTER_TAG, chain_index); pos = encode(pos, COUNTER_TAG, chain_index);
pos = encode(pos, CIPHERTEXT_TAG, *ciphertext_ptr, ciphertext_length); pos = encode(pos, CIPHERTEXT_TAG, *ciphertext_ptr, ciphertext_length);
} }
void _olm_decode_group_message(
const uint8_t *input, size_t input_length,
size_t mac_length,
struct _OlmDecodeGroupMessageResults *results
) {
std::uint8_t const * pos = input;
std::uint8_t const * end = input + input_length - mac_length;
std::uint8_t const * unknown = nullptr;
results->session_id = nullptr;
results->session_id_length = 0;
bool has_chain_index = false;
results->chain_index = 0;
results->ciphertext = nullptr;
results->ciphertext_length = 0;
if (pos == end) return;
if (input_length < mac_length) return;
results->version = *(pos++);
while (pos != end) {
pos = decode(
pos, end, GROUP_SESSION_ID_TAG,
results->session_id, results->session_id_length
);
pos = decode(
pos, end, COUNTER_TAG,
results->chain_index, has_chain_index
);
pos = decode(
pos, end, CIPHERTEXT_TAG,
results->ciphertext, results->ciphertext_length
);
if (unknown == pos) {
pos = skip_unknown(pos, end);
}
unknown = pos;
}
results->has_chain_index = (int)has_chain_index;
}

View file

@ -12,6 +12,7 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "olm/inbound_group_session.h"
#include "olm/outbound_group_session.h" #include "olm/outbound_group_session.h"
#include "unittest.hh" #include "unittest.hh"
@ -19,11 +20,10 @@
int main() { int main() {
{ {
TestCase test_case("Pickle outbound group"); TestCase test_case("Pickle outbound group");
size_t size = olm_outbound_group_session_size(); size_t size = olm_outbound_group_session_size();
void *memory = alloca(size); uint8_t memory[size];
OlmOutboundGroupSession *session = olm_outbound_group_session(memory); OlmOutboundGroupSession *session = olm_outbound_group_session(memory);
size_t pickle_length = olm_pickle_outbound_group_session_length(session); size_t pickle_length = olm_pickle_outbound_group_session_length(session);
@ -61,9 +61,9 @@ int main() {
"0123456789ABDEF0123456789ABCDEF"; "0123456789ABDEF0123456789ABCDEF";
/* build the outbound session */
size_t size = olm_outbound_group_session_size(); size_t size = olm_outbound_group_session_size();
void *memory = alloca(size); uint8_t memory[size];
OlmOutboundGroupSession *session = olm_outbound_group_session(memory); OlmOutboundGroupSession *session = olm_outbound_group_session(memory);
assert_equals((size_t)132, assert_equals((size_t)132,
@ -73,18 +73,48 @@ int main() {
session, random_bytes, sizeof(random_bytes)); session, random_bytes, sizeof(random_bytes));
assert_equals((size_t)0, res); assert_equals((size_t)0, res);
assert_equals(0U, olm_outbound_group_session_message_index(session));
size_t session_key_len = olm_outbound_group_session_key_length(session);
uint8_t session_key[session_key_len];
olm_outbound_group_session_key(session, session_key, session_key_len);
/* encode the message */
uint8_t plaintext[] = "Message"; uint8_t plaintext[] = "Message";
size_t plaintext_length = sizeof(plaintext) - 1; size_t plaintext_length = sizeof(plaintext) - 1;
size_t msglen = olm_group_encrypt_message_length( size_t msglen = olm_group_encrypt_message_length(
session, plaintext_length); session, plaintext_length);
uint8_t *msg = (uint8_t *)alloca(msglen); uint8_t msg[msglen];
res = olm_group_encrypt(session, plaintext, plaintext_length, res = olm_group_encrypt(session, plaintext, plaintext_length,
msg, msglen); msg, msglen);
assert_equals(msglen, res); assert_equals(msglen, res);
assert_equals(1U, olm_outbound_group_session_message_index(session));
// TODO: decode the message
/* build the inbound session */
size = olm_inbound_group_session_size();
uint8_t inbound_session_memory[size];
OlmInboundGroupSession *inbound_session =
olm_inbound_group_session(inbound_session_memory);
res = olm_init_inbound_group_session(
inbound_session, 0U, session_key, session_key_len);
assert_equals((size_t)0, res);
/* decode the message */
/* olm_group_decrypt_max_plaintext_length destroys the input so we have to
copy it. */
uint8_t msgcopy[msglen];
memcpy(msgcopy, msg, msglen);
size = olm_group_decrypt_max_plaintext_length(inbound_session, msgcopy, msglen);
uint8_t plaintext_buf[size];
res = olm_group_decrypt(inbound_session, msg, msglen,
plaintext_buf, size);
assert_equals(plaintext_length, res);
assert_equals(plaintext, plaintext_buf, res);
} }
} }

View file

@ -97,4 +97,26 @@ assert_equals(message2, output, 35);
assert_equals(output+sizeof(expected)-1, ciphertext_ptr); assert_equals(output+sizeof(expected)-1, ciphertext_ptr);
} /* group message encode test */ } /* group message encode test */
{
TestCase test_case("Group message decode test");
struct _OlmDecodeGroupMessageResults results;
std::uint8_t message[] =
"\x03"
"\x2A\x09sessionid"
"\x10\xc8\x01"
"\x22\x0A" "ciphertext"
"hmacsha2";
const uint8_t expected_session_id[] = "sessionid";
_olm_decode_group_message(message, sizeof(message)-1, 8, &results);
assert_equals(std::uint8_t(3), results.version);
assert_equals(std::size_t(9), results.session_id_length);
assert_equals(expected_session_id, results.session_id, 9);
assert_equals(1, results.has_chain_index);
assert_equals(std::uint32_t(200), results.chain_index);
assert_equals(std::size_t(10), results.ciphertext_length);
assert_equals(ciphertext, results.ciphertext, 10);
} /* group message decode test */
} }