Fail decoding base64 of invalid length.

olm::decode_base64 now returns the length of the raw decoded data on
success. When given input with an invalid base64 length, it fails early
(before decoding any input) and returns -1.

This also makes the C function _olm_decode_base64 an actual binding of
olm::decode_base64 instead of a wrapper with slightly different
behaviour.
This commit is contained in:
Denis Kasak 2021-05-18 11:38:33 +02:00
parent a5efc08ef3
commit e82f2601b0
3 changed files with 48 additions and 5 deletions

View file

@ -51,8 +51,12 @@ std::size_t decode_base64_length(
* Writes decode_base64_length(input_length) bytes to the output buffer. * Writes decode_base64_length(input_length) bytes to the output buffer.
* The output can overlap with the first three quarters of the input buffer. * The output can overlap with the first three quarters of the input buffer.
* That is, the input pointers and output pointer may be the same. * That is, the input pointers and output pointer may be the same.
*
* Returns the number of bytes of raw data the base64 input decoded to. If the
* input length supplied is not a valid length for base64, returns
* std::size_t(-1) and does not decode.
*/ */
std::uint8_t const * decode_base64( std::size_t decode_base64(
std::uint8_t const * input, std::size_t input_length, std::uint8_t const * input, std::size_t input_length,
std::uint8_t * output std::uint8_t * output
); );

View file

@ -12,6 +12,8 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include <cassert>
#include "olm/base64.h" #include "olm/base64.h"
#include "olm/base64.hh" #include "olm/base64.hh"
@ -101,12 +103,19 @@ std::size_t olm::decode_base64_length(
} }
std::uint8_t const * olm::decode_base64( std::size_t olm::decode_base64(
std::uint8_t const * input, std::size_t input_length, std::uint8_t const * input, std::size_t input_length,
std::uint8_t * output std::uint8_t * output
) { ) {
size_t raw_length = olm::decode_base64_length(input_length);
if (raw_length == std::size_t(-1)) {
return std::size_t(-1);
}
std::uint8_t const * end = input + (input_length / 4) * 4; std::uint8_t const * end = input + (input_length / 4) * 4;
std::uint8_t const * pos = input; std::uint8_t const * pos = input;
while (pos != end) { while (pos != end) {
unsigned value = DECODE_BASE64[pos[0] & 0x7F]; unsigned value = DECODE_BASE64[pos[0] & 0x7F];
value <<= 6; value |= DECODE_BASE64[pos[1] & 0x7F]; value <<= 6; value |= DECODE_BASE64[pos[1] & 0x7F];
@ -118,8 +127,19 @@ std::uint8_t const * olm::decode_base64(
value >>= 8; output[0] = value; value >>= 8; output[0] = value;
output += 3; output += 3;
} }
unsigned remainder = input + input_length - pos; unsigned remainder = input + input_length - pos;
if (remainder) { if (remainder) {
/* A base64 payload with a single byte remainder cannot occur because
* a single base64 character only encodes 6 bits, which is less than
* a full byte. Therefore, a minimum of two base64 characters are
* required to construct a single output byte and payloads with
* a remainder of 1 are illegal.
*
* Should never be the case due to length check above.
*/
assert(remainder != 1);
unsigned value = DECODE_BASE64[pos[0] & 0x7F]; unsigned value = DECODE_BASE64[pos[0] & 0x7F];
value <<= 6; value |= DECODE_BASE64[pos[1] & 0x7F]; value <<= 6; value |= DECODE_BASE64[pos[1] & 0x7F];
if (remainder == 3) { if (remainder == 3) {
@ -132,7 +152,8 @@ std::uint8_t const * olm::decode_base64(
} }
output[0] = value; output[0] = value;
} }
return input + input_length;
return raw_length;
} }
@ -162,6 +183,5 @@ size_t _olm_decode_base64(
uint8_t const * input, size_t input_length, uint8_t const * input, size_t input_length,
uint8_t * output uint8_t * output
) { ) {
olm::decode_base64(input, input_length, output); return olm::decode_base64(input, input_length, output);
return olm::decode_base64_length(input_length);
} }

View file

@ -66,5 +66,24 @@ assert_equals(std::size_t(11), output_length);
assert_equals(expected_output, output, output_length); assert_equals(expected_output, output, output_length);
} }
{
TestCase test_case("Decoding base64 of invalid length fails with -1");
#include <iostream>
std::uint8_t input[] = "SGVsbG8gV29ybGQab";
std::size_t input_length = sizeof(input) - 1;
/* We use a longer but valid input length here so that we don't get back -1.
* Nothing will be written to the output buffer anyway because the input is
* invalid. */
std::size_t buf_length = olm::decode_base64_length(input_length + 1);
std::uint8_t output[buf_length];
std::uint8_t expected_output[buf_length];
memset(output, 0, buf_length);
memset(expected_output, 0, buf_length);
std::size_t output_length = ::_olm_decode_base64(input, input_length, output);
assert_equals(std::size_t(-1), output_length);
assert_equals(0, memcmp(output, expected_output, buf_length));
}
} }