diff --git a/android/olm-sdk/src/main/java/org/matrix/olm/OlmPkSigning.java b/android/olm-sdk/src/main/java/org/matrix/olm/OlmPkSigning.java new file mode 100644 index 0000000..89cb556 --- /dev/null +++ b/android/olm-sdk/src/main/java/org/matrix/olm/OlmPkSigning.java @@ -0,0 +1,100 @@ +/* + * Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.matrix.olm; + +import android.util.Log; + +import java.util.Arrays; + +public class OlmPkSigning { + private static final String LOG_TAG = "OlmPkSigning"; + + /** PK Signing Id returned by JNI. + * This value uniquely identifies the native PK signing instance. + **/ + private transient long mNativeId; + + public OlmPkSigning() throws OlmException { + try { + mNativeId = createNewPkSigningJni(); + } catch (Exception e) { + throw new OlmException(OlmException.EXCEPTION_CODE_PK_SIGNING_CREATION, e.getMessage()); + } + } + + private native long createNewPkSigningJni(); + + private native void releasePkSigningJni(); + + public void releaseSigning() { + if (0 != mNativeId) { + releasePkSigningJni(); + } + mNativeId = 0; + } + + public boolean isReleased() { + return (0 == mNativeId); + } + + public static native int seedLength(); + + public static byte[] generateSeed() throws OlmException { + try { + return generateSeedJni(); + } catch (Exception e) { + Log.e(LOG_TAG, "## generateSeed(): failed " + e.getMessage()); + throw new OlmException(OlmException.EXCEPTION_CODE_PK_SIGNING_GENERATE_SEED, e.getMessage()); + } + } + + public static native byte[] generateSeedJni(); + + public String initWithSeed(byte[] seed) throws OlmException { + try { + byte[] pubKey = setKeyFromSeedJni(seed); + return new String(pubKey, "UTF-8"); + } catch (Exception e) { + Log.e(LOG_TAG, "## initWithSeed(): failed " + e.getMessage()); + throw new OlmException(OlmException.EXCEPTION_CODE_PK_SIGNING_INIT_WITH_SEED, e.getMessage()); + } + } + + public native byte[] setKeyFromSeedJni(byte[] seed); + + public String sign(String aMessage) throws OlmException { + if (null == aMessage) { + return null; + } + + byte[] messageBuffer = null; + try { + messageBuffer = aMessage.getBytes("UTF-8"); + byte[] signature = pkSignJni(messageBuffer); + return new String(signature, "UTF-8"); + } catch (Exception e) { + Log.e(LOG_TAG, "## pkSign(): failed " + e.getMessage()); + throw new OlmException(OlmException.EXCEPTION_CODE_PK_SIGNING_SIGN, e.getMessage()); + } finally { + if (null != messageBuffer) { + Arrays.fill(messageBuffer, (byte) 0); + } + } + } + + private native byte[] pkSignJni(byte[] message); +} diff --git a/include/olm/sas.h b/include/olm/sas.h index 46d4176..ec90ae7 100644 --- a/include/olm/sas.h +++ b/include/olm/sas.h @@ -147,6 +147,14 @@ size_t olm_sas_calculate_mac( void * mac, size_t mac_length ); +// for compatibility with an old version of Riot +size_t olm_sas_calculate_mac_long_kdf( + OlmSAS * sas, + void * input, size_t input_length, + const void * info, size_t info_length, + void * mac, size_t mac_length +); + /** @} */ // end of SAS group #ifdef __cplusplus diff --git a/javascript/olm_sas.js b/javascript/olm_sas.js index d5044ce..a2f82ee 100644 --- a/javascript/olm_sas.js +++ b/javascript/olm_sas.js @@ -75,3 +75,19 @@ SAS.prototype['calculate_mac'] = restore_stack(function(input, info) { ); return Pointer_stringify(mac_buffer); }); + +SAS.prototype['calculate_mac_long_kdf'] = restore_stack(function(input, info) { + var input_array = array_from_string(input); + var input_buffer = stack(input_array); + var info_array = array_from_string(info); + var info_buffer = stack(info_array); + var mac_length = sas_method(Module['_olm_sas_mac_length'])(this.ptr); + var mac_buffer = stack(mac_length + NULL_BYTE_PADDING_LENGTH); + sas_method(Module['_olm_sas_calculate_mac_long_kdf'])( + this.ptr, + input_buffer, input_array.length, + info_buffer, info_array.length, + mac_buffer, mac_length + ); + return Pointer_stringify(mac_buffer); +}); diff --git a/python/MANIFEST.in b/python/MANIFEST.in index bfddd4f..db6309d 100644 --- a/python/MANIFEST.in +++ b/python/MANIFEST.in @@ -1,3 +1,4 @@ -include olm.h +include include/olm/olm.h +include include/olm/pk.h include Makefile include olm_build.py diff --git a/python/Makefile b/python/Makefile index 6283fb5..e4d0611 100644 --- a/python/Makefile +++ b/python/Makefile @@ -1,18 +1,26 @@ all: olm-python2 olm-python3 -include/olm/olm.h: ../include/olm/olm.h ../include/olm/inbound_group_session.h ../include/olm/outbound_group_session.h +OLM_HEADERS = ../include/olm/olm.h ../include/olm/inbound_group_session.h \ + ../include/olm/outbound_group_session.h \ + +include/olm/olm.h: $(OLM_HEADERS) mkdir -p include/olm $(CPP) -I dummy -I ../include ../include/olm/olm.h -o include/olm/olm.h # add memset to the header so that we can use it to clear buffers echo 'void *memset(void *s, int c, size_t n);' >> include/olm/olm.h +include/olm/pk.h: include/olm/olm.h ../include/olm/pk.h + $(CPP) -I dummy -I ../include ../include/olm/pk.h -o include/olm/pk.h + include/olm/sas.h: include/olm/olm.h ../include/olm/sas.h $(CPP) -I dummy -I ../include ../include/olm/sas.h -o include/olm/sas.h -olm-python2: include/olm/olm.h include/olm/sas.h +headers: include/olm/olm.h include/olm/pk.h include/olm/sas.h + +olm-python2: headers DEVELOP=$(DEVELOP) python2 setup.py build -olm-python3: include/olm/olm.h include/olm/sas.h +olm-python3: headers DEVELOP=$(DEVELOP) python3 setup.py build install: install-python2 install-python3 diff --git a/python/olm/__init__.py b/python/olm/__init__.py index 9ac45b0..d92b0ab 100644 --- a/python/olm/__init__.py +++ b/python/olm/__init__.py @@ -36,4 +36,11 @@ from .group_session import ( OutboundGroupSession, OlmGroupSessionError ) +from .pk import ( + PkMessage, + PkEncryption, + PkDecryption, + PkEncryptionError, + PkDecryptionError +) from .sas import Sas, OlmSasError diff --git a/python/olm/pk.py b/python/olm/pk.py new file mode 100644 index 0000000..b67d5a4 --- /dev/null +++ b/python/olm/pk.py @@ -0,0 +1,346 @@ +# -*- coding: utf-8 -*- +# libolm python bindings +# Copyright © 2018 Damir Jelić +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""libolm PK module. + +This module contains bindings to the PK part of the Olm library. +It contains two classes PkDecryption and PkEncryption that are used to +establish an encrypted communication channel using public key encryption. + +Examples: + >>> decryption = PkDecryption() + >>> encryption = PkEncryption(decryption.public_key) + >>> plaintext = "It's a secret to everybody." + >>> message = encryption.encrypt(plaintext) + >>> decrypted_plaintext = decryption.decrypt(message) + +""" + +from builtins import super +from typing import AnyStr, Type +from future.utils import bytes_to_native_str + +from _libolm import ffi, lib # type: ignore +from ._finalize import track_for_finalization +from ._compat import URANDOM, to_bytearray + + +class PkEncryptionError(Exception): + """libolm Pk encryption exception.""" + + +class PkDecryptionError(Exception): + """libolm Pk decryption exception.""" + + +def _clear_pk_encryption(pk_struct): + lib.olm_clear_pk_encryption(pk_struct) + + +class PkMessage(object): + """A PK encrypted message.""" + + def __init__(self, ephemeral_key, mac, ciphertext): + # type: (str, str, str) -> None + """Create a new PK encrypted message. + + Args: + ephemeral_key(str): the public part of the ephemeral key + used (together with the recipient's key) to generate a symmetric + encryption key. + mac(str): Message Authentication Code of the encrypted message + ciphertext(str): The cipher text of the encrypted message + """ + self.ephemeral_key = ephemeral_key + self.mac = mac + self.ciphertext = ciphertext + + +class PkEncryption(object): + """PkEncryption class. + + Represents the decryption part of a PK encrypted channel. + """ + + def __init__(self, recipient_key): + # type: (AnyStr) -> None + """Create a new PK encryption object. + + Args: + recipient_key(str): a public key that will be used for encryption + """ + if not recipient_key: + raise ValueError("Recipient key can't be empty") + + self._buf = ffi.new("char[]", lib.olm_pk_encryption_size()) + self._pk_encryption = lib.olm_pk_encryption(self._buf) + track_for_finalization(self, self._pk_encryption, _clear_pk_encryption) + + byte_key = to_bytearray(recipient_key) + lib.olm_pk_encryption_set_recipient_key( + self._pk_encryption, + ffi.from_buffer(byte_key), + len(byte_key) + ) + + # clear out copies of the key + if byte_key is not recipient_key: # pragma: no cover + for i in range(0, len(byte_key)): + byte_key[i] = 0 + + def _check_error(self, ret): # pragma: no cover + # type: (int) -> None + if ret != lib.olm_error(): + return + + last_error = bytes_to_native_str( + ffi.string(lib.olm_pk_encryption_last_error(self._pk_encryption))) + + raise PkEncryptionError(last_error) + + def encrypt(self, plaintext): + # type: (AnyStr) -> PkMessage + """Encrypt a message. + + Returns the encrypted PkMessage. + + Args: + plaintext(str): A string that will be encrypted using the + PkEncryption object. + """ + byte_plaintext = to_bytearray(plaintext) + + r_length = lib.olm_pk_encrypt_random_length(self._pk_encryption) + random = URANDOM(r_length) + random_buffer = ffi.new("char[]", random) + + ciphertext_length = lib.olm_pk_ciphertext_length( + self._pk_encryption, len(byte_plaintext) + ) + ciphertext = ffi.new("char[]", ciphertext_length) + + mac_length = lib.olm_pk_mac_length(self._pk_encryption) + mac = ffi.new("char[]", mac_length) + + ephemeral_key_size = lib.olm_pk_key_length() + ephemeral_key = ffi.new("char[]", ephemeral_key_size) + + ret = lib.olm_pk_encrypt( + self._pk_encryption, + ffi.from_buffer(byte_plaintext), len(byte_plaintext), + ciphertext, ciphertext_length, + mac, mac_length, + ephemeral_key, ephemeral_key_size, + random_buffer, r_length + ) + + try: + self._check_error(ret) + finally: # pragma: no cover + # clear out copies of plaintext + if byte_plaintext is not plaintext: + for i in range(0, len(byte_plaintext)): + byte_plaintext[i] = 0 + + message = PkMessage( + bytes_to_native_str( + ffi.unpack(ephemeral_key, ephemeral_key_size)), + bytes_to_native_str( + ffi.unpack(mac, mac_length)), + bytes_to_native_str( + ffi.unpack(ciphertext, ciphertext_length)) + ) + return message + + +def _clear_pk_decryption(pk_struct): + lib.olm_clear_pk_decryption(pk_struct) + + +class PkDecryption(object): + """PkDecryption class. + + Represents the decryption part of a PK encrypted channel. + + Attributes: + public_key (str): The public key of the PkDecryption object, can be + shared and used to create a PkEncryption object. + + """ + + def __new__(cls): + # type: (Type[PkDecryption]) -> PkDecryption + obj = super().__new__(cls) + obj._buf = ffi.new("char[]", lib.olm_pk_decryption_size()) + obj._pk_decryption = lib.olm_pk_decryption(obj._buf) + obj.public_key = None + track_for_finalization(obj, obj._pk_decryption, _clear_pk_decryption) + return obj + + def __init__(self): + if False: # pragma: no cover + self._pk_decryption = self._pk_decryption # type: ffi.cdata + + random_length = lib.olm_pk_private_key_length() + random = URANDOM(random_length) + random_buffer = ffi.new("char[]", random) + + key_length = lib.olm_pk_key_length() + key_buffer = ffi.new("char[]", key_length) + + ret = lib.olm_pk_key_from_private( + self._pk_decryption, + key_buffer, key_length, + random_buffer, random_length + ) + self._check_error(ret) + self.public_key = bytes_to_native_str(ffi.unpack( + key_buffer, + key_length + )) + + def _check_error(self, ret): + # type: (int) -> None + if ret != lib.olm_error(): + return + + last_error = bytes_to_native_str( + ffi.string(lib.olm_pk_decryption_last_error(self._pk_decryption))) + + raise PkDecryptionError(last_error) + + def pickle(self, passphrase=""): + # type: (str) -> bytes + """Store a PkDecryption object. + + Stores a PkDecryption object as a base64 string. Encrypts the object + using the supplied passphrase. Returns a byte object containing the + base64 encoded string of the pickled session. + + Args: + passphrase(str, optional): The passphrase to be used to encrypt + the object. + """ + byte_key = to_bytearray(passphrase) + + pickle_length = lib.olm_pickle_pk_decryption_length( + self._pk_decryption + ) + pickle_buffer = ffi.new("char[]", pickle_length) + + ret = lib.olm_pickle_pk_decryption( + self._pk_decryption, + ffi.from_buffer(byte_key), len(byte_key), + pickle_buffer, pickle_length + ) + try: + self._check_error(ret) + finally: + # zero out copies of the passphrase + for i in range(0, len(byte_key)): + byte_key[i] = 0 + + return ffi.unpack(pickle_buffer, pickle_length) + + @classmethod + def from_pickle(cls, pickle, passphrase=""): + # types: (bytes, str) -> PkDecryption + """Restore a previously stored PkDecryption object. + + Creates a PkDecryption object from a pickled base64 string. Decrypts + the pickled object using the supplied passphrase. + Raises PkDecryptionError on failure. If the passphrase + doesn't match the one used to encrypt the session then the error + message for the exception will be "BAD_ACCOUNT_KEY". If the base64 + couldn't be decoded then the error message will be "INVALID_BASE64". + + Args: + pickle(bytes): Base64 encoded byte string containing the pickled + PkDecryption object + passphrase(str, optional): The passphrase used to encrypt the + object + """ + if not pickle: + raise ValueError("Pickle can't be empty") + + byte_key = to_bytearray(passphrase) + pickle_buffer = ffi.new("char[]", pickle) + + pubkey_length = lib.olm_pk_key_length() + pubkey_buffer = ffi.new("char[]", pubkey_length) + + obj = cls.__new__(cls) + + ret = lib.olm_unpickle_pk_decryption( + obj._pk_decryption, + ffi.from_buffer(byte_key), len(byte_key), + pickle_buffer, len(pickle), + pubkey_buffer, pubkey_length) + + try: + obj._check_error(ret) + finally: + for i in range(0, len(byte_key)): + byte_key[i] = 0 + + obj.public_key = bytes_to_native_str(ffi.unpack( + pubkey_buffer, + pubkey_length + )) + + return obj + + def decrypt(self, message): + # type (PkMessage) -> str + """Decrypt a previously encrypted Pk message. + + Returns the decrypted plaintext. + Raises PkDecryptionError on failure. + + Args: + message(PkMessage): the pk message to decrypt. + """ + ephemeral_key = to_bytearray(message.ephemeral_key) + ephemeral_key_size = len(ephemeral_key) + + mac = to_bytearray(message.mac) + mac_length = len(mac) + + ciphertext = to_bytearray(message.ciphertext) + ciphertext_length = len(ciphertext) + + max_plaintext_length = lib.olm_pk_max_plaintext_length( + self._pk_decryption, + ciphertext_length + ) + plaintext_buffer = ffi.new("char[]", max_plaintext_length) + + ret = lib.olm_pk_decrypt( + self._pk_decryption, + ffi.from_buffer(ephemeral_key), ephemeral_key_size, + ffi.from_buffer(mac), mac_length, + ffi.from_buffer(ciphertext), ciphertext_length, + plaintext_buffer, max_plaintext_length) + self._check_error(ret) + + plaintext = (ffi.unpack( + plaintext_buffer, + ret + )) + + # clear out copies of the plaintext + lib.memset(plaintext_buffer, 0, max_plaintext_length) + + return bytes_to_native_str(plaintext) diff --git a/python/olm_build.py b/python/olm_build.py index 1c610a1..0606337 100644 --- a/python/olm_build.py +++ b/python/olm_build.py @@ -18,6 +18,7 @@ from __future__ import unicode_literals import os +import subprocess from cffi import FFI @@ -32,6 +33,8 @@ link_args = ["-L../build"] if DEVELOP and DEVELOP.lower() in ["yes", "true", "1"]: link_args.append('-Wl,-rpath=../build') +headers_build = subprocess.Popen("make headers", shell=True) +headers_build.wait() ffibuilder.set_source( "_libolm", @@ -39,6 +42,7 @@ ffibuilder.set_source( #include #include #include + #include #include """, libraries=["olm"], @@ -48,6 +52,9 @@ ffibuilder.set_source( with open(os.path.join(PATH, "include/olm/olm.h")) as f: ffibuilder.cdef(f.read(), override=True) +with open(os.path.join(PATH, "include/olm/pk.h")) as f: + ffibuilder.cdef(f.read(), override=True) + with open(os.path.join(PATH, "include/olm/sas.h")) as f: ffibuilder.cdef(f.read(), override=True) diff --git a/python/setup.py b/python/setup.py index 4b0deb1..5742fd9 100644 --- a/python/setup.py +++ b/python/setup.py @@ -22,6 +22,10 @@ setup( packages=["olm"], setup_requires=["cffi>=1.0.0"], cffi_modules=["olm_build.py:ffibuilder"], - install_requires=["cffi>=1.0.0", "future", "typing"], + install_requires=[ + "cffi>=1.0.0", + "future", + "typing;python_version<'3.5'" + ], zip_safe=False ) diff --git a/python/tests/pk_test.py b/python/tests/pk_test.py new file mode 100644 index 0000000..f2aa147 --- /dev/null +++ b/python/tests/pk_test.py @@ -0,0 +1,49 @@ +import pytest + +from olm import PkDecryption, PkDecryptionError, PkEncryption + + +class TestClass(object): + def test_invalid_encryption(self): + with pytest.raises(ValueError): + PkEncryption("") + + def test_decrytion(self): + decryption = PkDecryption() + encryption = PkEncryption(decryption.public_key) + plaintext = "It's a secret to everybody." + message = encryption.encrypt(plaintext) + decrypted_plaintext = decryption.decrypt(message) + isinstance(decrypted_plaintext, str) + assert plaintext == decrypted_plaintext + + def test_invalid_decrytion(self): + decryption = PkDecryption() + encryption = PkEncryption(decryption.public_key) + plaintext = "It's a secret to everybody." + message = encryption.encrypt(plaintext) + message.ephemeral_key = "?" + with pytest.raises(PkDecryptionError): + decryption.decrypt(message) + + def test_pickling(self): + decryption = PkDecryption() + encryption = PkEncryption(decryption.public_key) + plaintext = "It's a secret to everybody." + message = encryption.encrypt(plaintext) + + pickle = decryption.pickle() + unpickled = PkDecryption.from_pickle(pickle) + decrypted_plaintext = unpickled.decrypt(message) + assert plaintext == decrypted_plaintext + + def test_invalid_unpickling(self): + with pytest.raises(ValueError): + PkDecryption.from_pickle("") + + def test_invalid_pass_pickling(self): + decryption = PkDecryption() + pickle = decryption.pickle("Secret") + + with pytest.raises(PkDecryptionError): + PkDecryption.from_pickle(pickle, "Not secret") diff --git a/src/sas.c b/src/sas.c index c5be73f..b5a3131 100644 --- a/src/sas.c +++ b/src/sas.c @@ -139,3 +139,26 @@ size_t olm_sas_calculate_mac( _olm_encode_base64((const uint8_t *)mac, SHA256_OUTPUT_LENGTH, (uint8_t *)mac); return 0; } + +// for compatibility with an old version of Riot +size_t olm_sas_calculate_mac_long_kdf( + OlmSAS * sas, + void * input, size_t input_length, + const void * info, size_t info_length, + void * mac, size_t mac_length +) { + if (mac_length < olm_sas_mac_length(sas)) { + sas->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; + return (size_t)-1; + } + uint8_t key[256]; + _olm_crypto_hkdf_sha256( + sas->secret, sizeof(sas->secret), + NULL, 0, + (const uint8_t *) info, info_length, + key, 256 + ); + _olm_crypto_hmac_sha256(key, 256, input, input_length, mac); + _olm_encode_base64((const uint8_t *)mac, SHA256_OUTPUT_LENGTH, (uint8_t *)mac); + return 0; +}