diff --git a/android/olm-sdk/src/main/java/org/matrix/olm/OlmSAS.java b/android/olm-sdk/src/main/java/org/matrix/olm/OlmSAS.java index a0589e9..ed08c25 100644 --- a/android/olm-sdk/src/main/java/org/matrix/olm/OlmSAS.java +++ b/android/olm-sdk/src/main/java/org/matrix/olm/OlmSAS.java @@ -106,6 +106,16 @@ public class OlmSAS { return null; } + public String calculateMacFixedBase64(String message, String info) throws OlmException { + try { + byte[] bytes = calculateMacFixedBase64Jni(message.getBytes("UTF-8"), info.getBytes("UTF-8")); + if (bytes != null) return new String(bytes, "UTF-8"); + } catch (UnsupportedEncodingException e) { + throw new OlmException(OlmException.EXCEPTION_CODE_SAS_ERROR, e.getMessage()); + } + return null; + } + public String calculateMacLongKdf(String message, String info) throws OlmException { try { byte[] bytes = calculateMacLongKdfJni(message.getBytes("UTF-8"), info.getBytes("UTF-8")); @@ -140,6 +150,8 @@ public class OlmSAS { private native byte[] calculateMacJni(byte[] message, byte[] info); + private native byte[] calculateMacFixedBase64Jni(byte[] message, byte[] info); + private native byte[] calculateMacLongKdfJni(byte[] message, byte[] info); /** diff --git a/android/olm-sdk/src/main/jni/olm_sas.cpp b/android/olm-sdk/src/main/jni/olm_sas.cpp index 400934f..06411a5 100644 --- a/android/olm-sdk/src/main/jni/olm_sas.cpp +++ b/android/olm-sdk/src/main/jni/olm_sas.cpp @@ -309,6 +309,86 @@ JNIEXPORT jbyteArray OLM_SAS_FUNC_DEF(calculateMacJni)(JNIEnv *env, jobject thiz return returnValue; } +JNIEXPORT jbyteArray OLM_SAS_FUNC_DEF(calculateMacFixedBase64Jni)(JNIEnv *env, jobject thiz,jbyteArray messageBuffer,jbyteArray infoBuffer) { + LOGD("## calculateMacFixedBase64Jni(): IN"); + const char* errorMessage = NULL; + jbyteArray returnValue = 0; + OlmSAS* sasPtr = getOlmSasInstanceId(env, thiz); + + jbyte *messagePtr = NULL; + jboolean messageWasCopied = JNI_FALSE; + + jbyte *infoPtr = NULL; + jboolean infoWasCopied = JNI_FALSE; + + if (!sasPtr) + { + LOGE("## calculateMacFixedBase64Jni(): failure - invalid SAS ptr=NULL"); + errorMessage = "invalid SAS ptr=NULL"; + } else if(!messageBuffer) { + LOGE("## calculateMacFixedBase64Jni(): failure - invalid message"); + errorMessage = "invalid info"; + } + else if (!(messagePtr = env->GetByteArrayElements(messageBuffer, &messageWasCopied))) + { + LOGE(" ## calculateMacFixedBase64Jni(): failure - message JNI allocation OOM"); + errorMessage = "message JNI allocation OOM"; + } + else if (!(infoPtr = env->GetByteArrayElements(infoBuffer, &infoWasCopied))) + { + LOGE(" ## calculateMacFixedBase64Jni(): failure - info JNI allocation OOM"); + errorMessage = "info JNI allocation OOM"; + } else { + + size_t infoLength = (size_t)env->GetArrayLength(infoBuffer); + size_t messageLength = (size_t)env->GetArrayLength(messageBuffer); + size_t macLength = olm_sas_mac_length(sasPtr); + + void *macPtr = malloc(macLength*sizeof(uint8_t)); + + size_t result = olm_sas_calculate_mac_fixed_base64(sasPtr,messagePtr,messageLength,infoPtr,infoLength,macPtr,macLength); + if (result == olm_error()) + { + errorMessage = (const char *)olm_sas_last_error(sasPtr); + LOGE("## calculateMacFixedBase64Jni(): failure - error calculating SAS mac Msg=%s", errorMessage); + } + else + { + returnValue = env->NewByteArray(macLength); + env->SetByteArrayRegion(returnValue, 0 , macLength, (jbyte*)macPtr); + } + + if (macPtr) { + free(macPtr); + } + } + + // free alloc + if (infoPtr) + { + if (infoWasCopied) + { + memset(infoPtr, 0, (size_t)env->GetArrayLength(infoBuffer)); + } + env->ReleaseByteArrayElements(infoBuffer, infoPtr, JNI_ABORT); + } + if (messagePtr) + { + if (messageWasCopied) + { + memset(messagePtr, 0, (size_t)env->GetArrayLength(messageBuffer)); + } + env->ReleaseByteArrayElements(messageBuffer, messagePtr, JNI_ABORT); + } + + if (errorMessage) + { + env->ThrowNew(env->FindClass("java/lang/Exception"), errorMessage); + } + + return returnValue; +} + JNIEXPORT jbyteArray OLM_SAS_FUNC_DEF(calculateMacLongKdfJni)(JNIEnv *env, jobject thiz,jbyteArray messageBuffer,jbyteArray infoBuffer) { LOGD("## calculateMacLongKdfJni(): IN"); const char* errorMessage = NULL; @@ -387,4 +467,4 @@ JNIEXPORT jbyteArray OLM_SAS_FUNC_DEF(calculateMacLongKdfJni)(JNIEnv *env, jobje } return returnValue; -} \ No newline at end of file +} diff --git a/android/olm-sdk/src/main/jni/olm_sas.h b/android/olm-sdk/src/main/jni/olm_sas.h index 3340459..967c0fb 100644 --- a/android/olm-sdk/src/main/jni/olm_sas.h +++ b/android/olm-sdk/src/main/jni/olm_sas.h @@ -32,6 +32,7 @@ JNIEXPORT jbyteArray OLM_SAS_FUNC_DEF(getPubKeyJni)(JNIEnv *env, jobject thiz); JNIEXPORT void OLM_SAS_FUNC_DEF(setTheirPubKey)(JNIEnv *env, jobject thiz,jbyteArray pubKey); JNIEXPORT jbyteArray OLM_SAS_FUNC_DEF(generateShortCodeJni)(JNIEnv *env, jobject thiz, jbyteArray infoStringBytes, jint byteNb); JNIEXPORT jbyteArray OLM_SAS_FUNC_DEF(calculateMacJni)(JNIEnv *env, jobject thiz, jbyteArray messageBuffer, jbyteArray infoBuffer); +JNIEXPORT jbyteArray OLM_SAS_FUNC_DEF(calculateMacFixedBase64Jni)(JNIEnv *env, jobject thiz, jbyteArray messageBuffer, jbyteArray infoBuffer); JNIEXPORT jbyteArray OLM_SAS_FUNC_DEF(calculateMacLongKdfJni)(JNIEnv *env, jobject thiz, jbyteArray messageBuffer, jbyteArray infoBuffer); #ifdef __cplusplus diff --git a/javascript/index.d.ts b/javascript/index.d.ts index 76efba9..566b649 100644 --- a/javascript/index.d.ts +++ b/javascript/index.d.ts @@ -121,6 +121,7 @@ declare class SAS { set_their_key(their_key: string): void; generate_bytes(info: string, length: number): Uint8Array; calculate_mac(input: string, info: string): string; + calculate_mac_fixed_base64(input: string, info: string): string; calculate_mac_long_kdf(input: string, info: string): string; } diff --git a/javascript/olm_sas.js b/javascript/olm_sas.js index 38535d5..be691ad 100644 --- a/javascript/olm_sas.js +++ b/javascript/olm_sas.js @@ -82,6 +82,22 @@ SAS.prototype['calculate_mac'] = restore_stack(function(input, info) { return UTF8ToString(mac_buffer, mac_length); }); +SAS.prototype['calculate_mac_fixed_base64'] = restore_stack(function(input, info) { + var input_array = array_from_string(input); + var input_buffer = stack(input_array); + var info_array = array_from_string(info); + var info_buffer = stack(info_array); + var mac_length = sas_method(Module['_olm_sas_mac_length'])(this.ptr); + var mac_buffer = stack(mac_length + NULL_BYTE_PADDING_LENGTH); + sas_method(Module['_olm_sas_calculate_mac_fixed_base64'])( + this.ptr, + input_buffer, input_array.length, + info_buffer, info_array.length, + mac_buffer, mac_length + ); + return UTF8ToString(mac_buffer, mac_length); +}); + SAS.prototype['calculate_mac_long_kdf'] = restore_stack(function(input, info) { var input_array = array_from_string(input); var input_buffer = stack(input_array); diff --git a/python/olm/sas.py b/python/olm/sas.py index cf2a443..b90d352 100644 --- a/python/olm/sas.py +++ b/python/olm/sas.py @@ -210,6 +210,37 @@ class Sas(object): ) return bytes_to_native_str(ffi.unpack(mac_buffer, mac_length)) + def calculate_mac_fixed_base64(self, message, extra_info): + # type: (str, str) -> str + """Generate a message authentication code based on the shared secret. + + Args: + message (str): The message to produce the authentication code for. + extra_info (str): Extra information to mix in when generating the + MAC + + Raises OlmSasError on failure. + + """ + byte_message = to_bytes(message) + byte_info = to_bytes(extra_info) + + mac_length = lib.olm_sas_mac_length(self._sas) + mac_buffer = ffi.new("char[]", mac_length) + + self._check_error( + lib.olm_sas_calculate_mac_fixed_base64( + self._sas, + ffi.from_buffer(byte_message), + len(byte_message), + ffi.from_buffer(byte_info), + len(byte_info), + mac_buffer, + mac_length + ) + ) + return bytes_to_native_str(ffi.unpack(mac_buffer, mac_length)) + def calculate_mac_long_kdf(self, message, extra_info): # type: (str, str) -> str """Generate a message authentication code based on the shared secret. diff --git a/xcode/OLMKit/OLMSAS.h b/xcode/OLMKit/OLMSAS.h index 3785b03..fa4e489 100644 --- a/xcode/OLMKit/OLMSAS.h +++ b/xcode/OLMKit/OLMSAS.h @@ -54,6 +54,17 @@ NS_ASSUME_NONNULL_BEGIN */ - (NSString *)calculateMac:(NSString*)input info:(NSString*)info error:(NSError* _Nullable *)error; +/** + Generate a message authentication code (MAC) based on the shared secret. + This version is compatible with other base64 implementations. + + @param input the message to produce the authentication code for. + @param info extra information to mix in when generating the MAC, as per the Matrix spec. + @param error the error if any. + @return the MAC. + */ +- (NSString *)calculateMacFixedBase64:(NSString*)input info:(NSString*)info error:(NSError* _Nullable *)error; + /** Generate a message authentication code (MAC) based on the shared secret. For compatibility with an old version of olm.js. diff --git a/xcode/OLMKit/OLMSAS.m b/xcode/OLMKit/OLMSAS.m index fed370b..21372ad 100644 --- a/xcode/OLMKit/OLMSAS.m +++ b/xcode/OLMKit/OLMSAS.m @@ -137,6 +137,40 @@ return mac; } +- (NSString *)calculateMacFixedBase64:(NSString *)input info:(NSString *)info error:(NSError *__autoreleasing _Nullable *)error { + NSMutableData *inputData = [input dataUsingEncoding:NSUTF8StringEncoding].mutableCopy; + NSData *infoData = [info dataUsingEncoding:NSUTF8StringEncoding]; + + size_t macLength = olm_sas_mac_length(olmSAS); + NSMutableData *macData = [NSMutableData dataWithLength:macLength]; + if (!macData) { + return nil; + } + + size_t result = olm_sas_calculate_mac_fixed_base64(olmSAS, + inputData.mutableBytes, inputData.length, + infoData.bytes, infoData.length, + macData.mutableBytes, macLength); + if (result == olm_error()) { + const char *olm_error = olm_sas_last_error(olmSAS); + NSLog(@"[OLMSAS] calculateMac: olm_sas_calculate_mac error: %s", olm_error); + + NSString *errorString = [NSString stringWithUTF8String:olm_error]; + if (error && olm_error && errorString) { + *error = [NSError errorWithDomain:OLMErrorDomain + code:0 + userInfo:@{ + NSLocalizedDescriptionKey: errorString, + NSLocalizedFailureReasonErrorKey: [NSString stringWithFormat:@"olm_sas_calculate_mac error: %@", errorString] + }]; + } + return nil; + } + + NSString *mac = [[NSString alloc] initWithData:macData encoding:NSUTF8StringEncoding]; + return mac; +} + - (NSString *)calculateMacLongKdf:(NSString *)input info:(NSString *)info error:(NSError *__autoreleasing _Nullable *)error { NSMutableData *inputData = [input dataUsingEncoding:NSUTF8StringEncoding].mutableCopy; NSData *infoData = [info dataUsingEncoding:NSUTF8StringEncoding];