Add encoder and decoder for PreKey messages

This commit is contained in:
Mark Haines 2015-06-11 15:57:45 +01:00
parent 816435a860
commit e44c82a7b4
4 changed files with 265 additions and 66 deletions

View file

@ -37,6 +37,7 @@ struct MessageWriter {
struct MessageReader { struct MessageReader {
std::uint8_t version; std::uint8_t version;
bool has_counter;
std::uint32_t counter; std::uint32_t counter;
std::uint8_t const * input; std::size_t input_length; std::uint8_t const * input; std::size_t input_length;
std::uint8_t const * ratchet_key; std::size_t ratchet_key_length; std::uint8_t const * ratchet_key; std::size_t ratchet_key_length;
@ -46,9 +47,8 @@ struct MessageReader {
/** /**
* Writes the message headers into the output buffer. * Writes the message headers into the output buffer.
* Returns a writer struct populated with pointers into the output buffer. * Populates the writer struct with pointers into the output buffer.
*/ */
void encode_message( void encode_message(
MessageWriter & writer, MessageWriter & writer,
std::uint8_t version, std::uint8_t version,
@ -62,13 +62,67 @@ void encode_message(
/** /**
* Reads the message headers from the input buffer. * Reads the message headers from the input buffer.
* Populates the reader struct with pointers into the input buffer. * Populates the reader struct with pointers into the input buffer.
* On failure returns std::size_t(-1).
*/ */
std::size_t decode_message( void decode_message(
MessageReader & reader, MessageReader & reader,
std::uint8_t const * input, std::size_t input_length, std::uint8_t const * input, std::size_t input_length,
std::size_t mac_length std::size_t mac_length
); );
struct PreKeyMessageWriter {
std::uint8_t * identity_key;
std::uint8_t * base_key;
std::uint8_t * message;
};
struct PreKeyMessageReader {
std::uint8_t version;
bool has_registration_id;
bool has_one_time_key_id;
std::uint32_t registration_id;
std::uint32_t one_time_key_id;
std::uint8_t const * identity_key; std::size_t identity_key_length;
std::uint8_t const * base_key; std::size_t base_key_length;
std::uint8_t const * message; std::size_t message_length;
};
/**
* The length of the buffer needed to hold a message.
*/
std::size_t encode_one_time_key_message_length(
std::uint32_t registration_id,
std::uint32_t one_time_key_id,
std::size_t identity_key_length,
std::size_t base_key_length,
std::size_t message_length
);
/**
* Writes the message headers into the output buffer.
* Populates the writer struct with pointers into the output buffer.
*/
void encode_one_time_key_message(
PreKeyMessageWriter & writer,
std::uint8_t version,
std::uint32_t registration_id,
std::uint32_t one_time_key_id,
std::size_t identity_key_length,
std::size_t base_key_length,
std::size_t message_length,
std::uint8_t * output
);
/**
* Reads the message headers from the input buffer.
* Populates the reader struct with pointers into the input buffer.
*/
void decode_one_time_key_message(
PreKeyMessageReader & reader,
std::uint8_t const * input, std::size_t input_length
);
} // namespace axolotl } // namespace axolotl

View file

@ -81,6 +81,82 @@ static std::uint8_t const RATCHET_KEY_TAG = 012;
static std::uint8_t const COUNTER_TAG = 020; static std::uint8_t const COUNTER_TAG = 020;
static std::uint8_t const CIPHERTEXT_TAG = 042; static std::uint8_t const CIPHERTEXT_TAG = 042;
std::uint8_t * encode(
std::uint8_t * pos,
std::uint8_t tag,
std::uint32_t value
) {
*(pos++) = tag;
return varint_encode(pos, value);
}
std::uint8_t * encode(
std::uint8_t * pos,
std::uint8_t tag,
std::uint8_t * & value, std::size_t value_length
) {
*(pos++) = tag;
pos = varint_encode(pos, value_length);
value = pos;
return pos + value_length;
}
std::uint8_t const * decode(
std::uint8_t const * pos, std::uint8_t const * end,
std::uint8_t tag,
std::uint32_t & value, bool & has_value
) {
if (pos != end && *pos == tag) {
++pos;
std::uint8_t const * value_start = pos;
pos = varint_skip(pos, end);
value = varint_decode<std::uint32_t>(value_start, pos);
has_value = true;
}
return pos;
}
std::uint8_t const * decode(
std::uint8_t const * pos, std::uint8_t const * end,
std::uint8_t tag,
std::uint8_t const * & value, std::size_t & value_length
) {
if (pos != end && *pos == tag) {
++pos;
std::uint8_t const * len_start = pos;
pos = varint_skip(pos, end);
std::size_t len = varint_decode<std::size_t>(len_start, pos);
if (len > end - pos) return end;
value = pos;
value_length = len;
pos += len;
}
return pos;
}
std::uint8_t const * skip_unknown(
std::uint8_t const * pos, std::uint8_t const * end
) {
if (pos != end) {
uint8_t tag = *pos;
if (tag & 0x7 == 0) {
pos = varint_skip(pos, end);
pos = varint_skip(pos, end);
} else if (tag & 0x7 == 2) {
pos = varint_skip(pos, end);
std::uint8_t const * len_start = pos;
pos = varint_skip(pos, end);
std::size_t len = varint_decode<std::size_t>(len_start, pos);
if (len > end - pos) return end;
pos += len;
} else {
return end;
}
}
return pos;
}
} // namespace } // namespace
@ -109,75 +185,138 @@ void axolotl::encode_message(
) { ) {
std::uint8_t * pos = output; std::uint8_t * pos = output;
*(pos++) = version; *(pos++) = version;
*(pos++) = COUNTER_TAG; pos = encode(pos, RATCHET_KEY_TAG, writer.ratchet_key, ratchet_key_length);
pos = varint_encode(pos, counter); pos = encode(pos, COUNTER_TAG, counter);
*(pos++) = RATCHET_KEY_TAG; pos = encode(pos, CIPHERTEXT_TAG, writer.ciphertext, ciphertext_length);
pos = varint_encode(pos, ratchet_key_length);
writer.ratchet_key = pos;
pos += ratchet_key_length;
*(pos++) = CIPHERTEXT_TAG;
pos = varint_encode(pos, ciphertext_length);
writer.ciphertext = pos;
pos += ciphertext_length;
} }
std::size_t axolotl::decode_message( void axolotl::decode_message(
axolotl::MessageReader & reader, axolotl::MessageReader & reader,
std::uint8_t const * input, std::size_t input_length, std::uint8_t const * input, std::size_t input_length,
std::size_t mac_length std::size_t mac_length
) { ) {
std::uint8_t const * pos = input; std::uint8_t const * pos = input;
std::uint8_t const * end = input + input_length - mac_length; std::uint8_t const * end = input + input_length - mac_length;
std::uint8_t flags = 0; std::uint8_t const * unknown = NULL;
std::size_t result = std::size_t(-1);
if (pos == end) return result; if (pos == end) return;
reader.version = *(pos++); reader.version = *(pos++);
while (pos != end) {
uint8_t tag = *(pos);
if (tag == COUNTER_TAG) {
++pos;
std::uint8_t const * counter_start = pos;
pos = varint_skip(pos, end);
reader.counter = varint_decode<std::uint32_t>(counter_start, pos);
flags |= 1;
} else if (tag == RATCHET_KEY_TAG) {
++pos;
std::uint8_t const * len_start = pos;
pos = varint_skip(pos, end);
std::size_t len = varint_decode<std::size_t>(len_start, pos);
if (len > end - pos) return result;
reader.ratchet_key_length = len;
reader.ratchet_key = pos;
pos += len;
flags |= 2;
} else if (tag == CIPHERTEXT_TAG) {
++pos;
std::uint8_t const * len_start = pos;
pos = varint_skip(pos, end);
std::size_t len = varint_decode<std::size_t>(len_start, pos);
if (len > end - pos) return result;
reader.ciphertext_length = len;
reader.ciphertext = pos;
pos += len;
flags |= 4;
} else if (tag & 0x7 == 0) {
pos = varint_skip(pos, end);
pos = varint_skip(pos, end);
} else if (tag & 0x7 == 2) {
std::uint8_t const * len_start = pos;
pos = varint_skip(pos, end);
std::size_t len = varint_decode<std::size_t>(len_start, pos);
if (len > end - pos) return result;
pos += len;
} else {
return std::size_t(-1);
}
}
if (flags == 0x7) {
reader.input = input; reader.input = input;
reader.input_length = input_length; reader.input_length = input_length;
return std::size_t(pos - input); reader.has_counter = false;
reader.ratchet_key = NULL;
reader.ciphertext = NULL;
while (pos != end) {
pos = decode(
pos, end, RATCHET_KEY_TAG,
reader.ratchet_key, reader.ratchet_key_length
);
pos = decode(
pos, end, COUNTER_TAG,
reader.counter, reader.has_counter
);
pos = decode(
pos, end, CIPHERTEXT_TAG,
reader.ciphertext, reader.ciphertext_length
);
if (unknown == pos) {
pos == skip_unknown(pos, end);
}
unknown = pos;
}
}
namespace {
static std::uint8_t const REGISTRATION_ID_TAG = 050;
static std::uint8_t const ONE_TIME_KEY_ID_TAG = 010;
static std::uint8_t const BASE_KEY_TAG = 022;
static std::uint8_t const IDENTITY_KEY_TAG = 032;
static std::uint8_t const MESSAGE_TAG = 042;
} // namespace
std::size_t axolotl::encode_one_time_key_message_length(
std::uint32_t registration_id,
std::uint32_t one_time_key_id,
std::size_t identity_key_length,
std::size_t base_key_length,
std::size_t message_length
) {
std::size_t length = VERSION_LENGTH;
length += 1 + varint_length(registration_id);
length += 1 + varint_length(one_time_key_id);
length += 1 + varstring_length(identity_key_length);
length += 1 + varstring_length(base_key_length);
length += 1 + varstring_length(message_length);
return length;
}
void axolotl::encode_one_time_key_message(
axolotl::PreKeyMessageWriter & writer,
std::uint8_t version,
std::uint32_t registration_id,
std::uint32_t one_time_key_id,
std::size_t identity_key_length,
std::size_t base_key_length,
std::size_t message_length,
std::uint8_t * output
) {
std::uint8_t * pos = output;
*(pos++) = version;
pos = encode(pos, REGISTRATION_ID_TAG, registration_id);
pos = encode(pos, ONE_TIME_KEY_ID_TAG, one_time_key_id);
pos = encode(pos, BASE_KEY_TAG, writer.base_key, base_key_length);
pos = encode(pos, IDENTITY_KEY_TAG, writer.identity_key, identity_key_length);
pos = encode(pos, MESSAGE_TAG, writer.message, message_length);
}
void axolotl::decode_one_time_key_message(
PreKeyMessageReader & reader,
std::uint8_t const * input, std::size_t input_length
) {
std::uint8_t const * pos = input;
std::uint8_t const * end = input + input_length;
std::uint8_t const * unknown = NULL;
if (pos == end) return;
reader.version = *(pos++);
reader.has_registration_id = false;
reader.has_one_time_key_id = false;
reader.identity_key = NULL;
reader.base_key = NULL;
reader.message = NULL;
while (pos != end) {
pos = decode(
pos, end, REGISTRATION_ID_TAG,
reader.registration_id, reader.has_registration_id
);
pos = decode(
pos, end, ONE_TIME_KEY_ID_TAG,
reader.one_time_key_id, reader.has_one_time_key_id
);
pos = decode(
pos, end, BASE_KEY_TAG,
reader.base_key, reader.base_key_length
);
pos = decode(
pos, end, IDENTITY_KEY_TAG,
reader.identity_key, reader.identity_key_length
);
pos = decode(
pos, end, MESSAGE_TAG,
reader.message, reader.message_length
);
if (unknown == pos) {
pos == skip_unknown(pos, end);
}
unknown = pos;
} }
return result;
} }

View file

@ -444,7 +444,7 @@ std::size_t axolotl::Session::decrypt(
} }
axolotl::MessageReader reader; axolotl::MessageReader reader;
std::size_t body_length = axolotl::decode_message( axolotl::decode_message(
reader, input, input_length, ratchet_cipher.mac_length() reader, input, input_length, ratchet_cipher.mac_length()
); );
@ -453,7 +453,12 @@ std::size_t axolotl::Session::decrypt(
return std::size_t(-1); return std::size_t(-1);
} }
if (body_length == size_t(-1) || reader.ratchet_key_length != KEY_LENGTH) { if (!reader.has_counter || !reader.ratchet_key || !reader.ciphertext) {
last_error = axolotl::ErrorCode::BAD_MESSAGE_FORMAT;
return std::size_t(-1);
}
if (reader.ratchet_key_length != KEY_LENGTH) {
last_error = axolotl::ErrorCode::BAD_MESSAGE_FORMAT; last_error = axolotl::ErrorCode::BAD_MESSAGE_FORMAT;
return std::size_t(-1); return std::size_t(-1);
} }

View file

@ -17,8 +17,8 @@
int main() { int main() {
std::uint8_t message1[36] = "\x03\n\nratchetkey\x10\x01\"\nciphertexthmacsha2"; std::uint8_t message1[36] = "\x03\x10\x01\n\nratchetkey\"\nciphertexthmacsha2";
std::uint8_t message2[36] = "\x03\x10\x01\n\nratchetkey\"\nciphertexthmacsha2"; std::uint8_t message2[36] = "\x03\n\nratchetkey\x10\x01\"\nciphertexthmacsha2";
std::uint8_t ratchetkey[11] = "ratchetkey"; std::uint8_t ratchetkey[11] = "ratchetkey";
std::uint8_t ciphertext[11] = "ciphertext"; std::uint8_t ciphertext[11] = "ciphertext";
std::uint8_t hmacsha2[9] = "hmacsha2"; std::uint8_t hmacsha2[9] = "hmacsha2";
@ -31,6 +31,7 @@ axolotl::MessageReader reader;
axolotl::decode_message(reader, message1, 35, 8); axolotl::decode_message(reader, message1, 35, 8);
assert_equals(std::uint8_t(3), reader.version); assert_equals(std::uint8_t(3), reader.version);
assert_equals(true, reader.has_counter);
assert_equals(std::uint32_t(1), reader.counter); assert_equals(std::uint32_t(1), reader.counter);
assert_equals(std::size_t(10), reader.ratchet_key_length); assert_equals(std::size_t(10), reader.ratchet_key_length);
assert_equals(std::size_t(10), reader.ciphertext_length); assert_equals(std::size_t(10), reader.ciphertext_length);