From 452d5960e4b540c50d1ab238e9fddae4a3bd1ece Mon Sep 17 00:00:00 2001 From: moiseev-signal <122060238+moiseev-signal@users.noreply.github.com> Date: Thu, 2 May 2024 06:28:32 -0700 Subject: [PATCH] Add test and extra cleanup around usage of incremental mac. --- .../crypto/AttachmentCipherInputStream.java | 166 ++++++++---------- .../api/crypto/AttachmentCipherTest.java | 2 - 2 files changed, 77 insertions(+), 91 deletions(-) diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.java index e744df0609..b5352fe950 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.java +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.java @@ -6,7 +6,6 @@ package org.whispersystems.signalservice.api.crypto; -import org.signal.libsignal.protocol.InvalidMacException; import org.signal.libsignal.protocol.InvalidMessageException; import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice; import org.signal.libsignal.protocol.incrementalmac.IncrementalMacInputStream; @@ -63,52 +62,45 @@ public class AttachmentCipherInputStream extends FilterInputStream { public static InputStream createForAttachment(File file, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest, byte[] incrementalDigest, int incrementalMacChunkSize) throws InvalidMessageException, IOException { - try { - byte[][] parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE); - Mac mac = Mac.getInstance("HmacSHA256"); - mac.init(new SecretKeySpec(parts[1], "HmacSHA256")); + byte[][] parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE); + Mac mac = initMac(parts[1]); - if (file.length() <= BLOCK_SIZE + mac.getMacLength()) { - throw new InvalidMessageException("Message shorter than crypto overhead!"); - } - - if (digest == null) { - throw new InvalidMacException("Missing digest!"); - } - - - final InputStream wrappedStream; - final boolean hasIncrementalMac = incrementalDigest != null && incrementalDigest.length > 0 && incrementalMacChunkSize > 0; - - if (!hasIncrementalMac) { - try (FileInputStream macVerificationStream = new FileInputStream(file)) { - verifyMac(macVerificationStream, file.length(), mac, digest); - } - wrappedStream = new FileInputStream(file); - } else { - wrappedStream = new IncrementalMacInputStream( - new IncrementalMacAdditionalValidationsInputStream( - new FileInputStream(file), - file.length(), - mac, - digest - ), - parts[1], - ChunkSizeChoice.everyNthByte(incrementalMacChunkSize), - incrementalDigest); - } - InputStream inputStream = new AttachmentCipherInputStream(wrappedStream, parts[0], file.length() - BLOCK_SIZE - mac.getMacLength()); - - if (plaintextLength != 0) { - inputStream = new ContentLengthInputStream(inputStream, plaintextLength); - } - - return inputStream; - } catch (NoSuchAlgorithmException | InvalidKeyException e) { - throw new AssertionError(e); - } catch (InvalidMacException e) { - throw new InvalidMessageException(e); + if (file.length() <= BLOCK_SIZE + mac.getMacLength()) { + throw new InvalidMessageException("Message shorter than crypto overhead!"); } + + if (digest == null) { + throw new InvalidMessageException("Missing digest!"); + } + + + final InputStream wrappedStream; + final boolean hasIncrementalMac = incrementalDigest != null && incrementalDigest.length > 0 && incrementalMacChunkSize > 0; + + if (!hasIncrementalMac) { + try (FileInputStream macVerificationStream = new FileInputStream(file)) { + verifyMac(macVerificationStream, file.length(), mac, digest); + } + wrappedStream = new FileInputStream(file); + } else { + wrappedStream = new IncrementalMacInputStream( + new IncrementalMacAdditionalValidationsInputStream( + new FileInputStream(file), + file.length(), + mac, + digest + ), + parts[1], + ChunkSizeChoice.everyNthByte(incrementalMacChunkSize), + incrementalDigest); + } + InputStream inputStream = new AttachmentCipherInputStream(wrappedStream, parts[0], file.length() - BLOCK_SIZE - mac.getMacLength()); + + if (plaintextLength != 0) { + inputStream = new ContentLengthInputStream(inputStream, plaintextLength); + } + + return inputStream; } /** @@ -117,55 +109,41 @@ public class AttachmentCipherInputStream extends FilterInputStream { public static InputStream createForArchivedMedia(BackupKey.MediaKeyMaterial archivedMediaKeyMaterial, File file, long originalCipherTextLength) throws InvalidMessageException, IOException { - try { - Mac mac = Mac.getInstance("HmacSHA256"); - mac.init(new SecretKeySpec(archivedMediaKeyMaterial.getMacKey(), "HmacSHA256")); + Mac mac = initMac(archivedMediaKeyMaterial.getMacKey()); - if (file.length() <= BLOCK_SIZE + mac.getMacLength()) { - throw new InvalidMessageException("Message shorter than crypto overhead!"); - } - - try (FileInputStream macVerificationStream = new FileInputStream(file)) { - verifyMac(macVerificationStream, file.length(), mac, null); - } - - InputStream inputStream = new AttachmentCipherInputStream(new FileInputStream(file), archivedMediaKeyMaterial.getCipherKey(), file.length() - BLOCK_SIZE - mac.getMacLength()); - - if (originalCipherTextLength != 0) { - inputStream = new ContentLengthInputStream(inputStream, originalCipherTextLength); - } - - return inputStream; - } catch (NoSuchAlgorithmException | InvalidKeyException e) { - throw new AssertionError(e); - } catch (InvalidMacException e) { - throw new InvalidMessageException(e); + if (file.length() <= BLOCK_SIZE + mac.getMacLength()) { + throw new InvalidMessageException("Message shorter than crypto overhead!"); } + + try (FileInputStream macVerificationStream = new FileInputStream(file)) { + verifyMac(macVerificationStream, file.length(), mac, null); + } + + InputStream inputStream = new AttachmentCipherInputStream(new FileInputStream(file), archivedMediaKeyMaterial.getCipherKey(), file.length() - BLOCK_SIZE - mac.getMacLength()); + + if (originalCipherTextLength != 0) { + inputStream = new ContentLengthInputStream(inputStream, originalCipherTextLength); + } + + return inputStream; } public static InputStream createForStickerData(byte[] data, byte[] packKey) throws InvalidMessageException, IOException { - try { - byte[] combinedKeyMaterial = HKDF.deriveSecrets(packKey, "Sticker Pack".getBytes(), 64); - byte[][] parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE); - Mac mac = Mac.getInstance("HmacSHA256"); - mac.init(new SecretKeySpec(parts[1], "HmacSHA256")); + byte[] combinedKeyMaterial = HKDF.deriveSecrets(packKey, "Sticker Pack".getBytes(), 64); + byte[][] parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE); + Mac mac = initMac(parts[1]); - if (data.length <= BLOCK_SIZE + mac.getMacLength()) { - throw new InvalidMessageException("Message shorter than crypto overhead!"); - } - - try (InputStream inputStream = new ByteArrayInputStream(data)) { - verifyMac(inputStream, data.length, mac, null); - } - - return new AttachmentCipherInputStream(new ByteArrayInputStream(data), parts[0], data.length - BLOCK_SIZE - mac.getMacLength()); - } catch (NoSuchAlgorithmException | InvalidKeyException e) { - throw new AssertionError(e); - } catch (InvalidMacException e) { - throw new InvalidMessageException(e); + if (data.length <= BLOCK_SIZE + mac.getMacLength()) { + throw new InvalidMessageException("Message shorter than crypto overhead!"); } + + try (InputStream inputStream = new ByteArrayInputStream(data)) { + verifyMac(inputStream, data.length, mac, null); + } + + return new AttachmentCipherInputStream(new ByteArrayInputStream(data), parts[0], data.length - BLOCK_SIZE - mac.getMacLength()); } private AttachmentCipherInputStream(InputStream inputStream, byte[] cipherKey, long totalDataSize) @@ -297,8 +275,18 @@ public class AttachmentCipherInputStream extends FilterInputStream { } } + private static Mac initMac(byte[] key) { + try { + Mac mac = Mac.getInstance("HmacSHA256"); + mac.init(new SecretKeySpec(key, "HmacSHA256")); + return mac; + } catch (NoSuchAlgorithmException | InvalidKeyException e) { + throw new AssertionError(e); + } + } + private static void verifyMac(@Nonnull InputStream inputStream, long length, @Nonnull Mac mac, @Nullable byte[] theirDigest) - throws InvalidMacException + throws InvalidMessageException { try { MessageDigest digest = MessageDigest.getInstance("SHA256"); @@ -317,17 +305,17 @@ public class AttachmentCipherInputStream extends FilterInputStream { Util.readFully(inputStream, theirMac); if (!MessageDigest.isEqual(ourMac, theirMac)) { - throw new InvalidMacException("MAC doesn't match!"); + throw new InvalidMessageException("MAC doesn't match!"); } byte[] ourDigest = digest.digest(theirMac); if (theirDigest != null && !MessageDigest.isEqual(ourDigest, theirDigest)) { - throw new InvalidMacException("Digest doesn't match!"); + throw new InvalidMessageException("Digest doesn't match!"); } } catch (IOException | ArithmeticException e1) { - throw new InvalidMacException(e1); + throw new InvalidMessageException(e1); } catch (NoSuchAlgorithmException e) { throw new AssertionError(e); } diff --git a/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherTest.java b/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherTest.java index 782c696c3e..0fdfda0afb 100644 --- a/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherTest.java +++ b/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherTest.java @@ -7,7 +7,6 @@ import org.signal.libsignal.protocol.InvalidMessageException; import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice; import org.signal.libsignal.protocol.incrementalmac.InvalidMacException; import org.signal.libsignal.protocol.kdf.HKDFv3; -import org.signal.libsignal.protocol.util.ByteUtil; import org.whispersystems.signalservice.api.backup.BackupKey; import org.whispersystems.signalservice.internal.crypto.PaddingInputStream; import org.whispersystems.signalservice.internal.push.http.AttachmentCipherOutputStreamFactory; @@ -16,7 +15,6 @@ import org.whispersystems.signalservice.internal.util.Util; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.File; -import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStream;