Add test and extra cleanup around usage of incremental mac.

This commit is contained in:
moiseev-signal
2024-05-02 06:28:32 -07:00
committed by Alex Hart
parent c95b180728
commit 452d5960e4
2 changed files with 77 additions and 91 deletions

View File

@@ -6,7 +6,6 @@
package org.whispersystems.signalservice.api.crypto;
import org.signal.libsignal.protocol.InvalidMacException;
import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice;
import org.signal.libsignal.protocol.incrementalmac.IncrementalMacInputStream;
@@ -63,52 +62,45 @@ public class AttachmentCipherInputStream extends FilterInputStream {
public static InputStream createForAttachment(File file, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest, byte[] incrementalDigest, int incrementalMacChunkSize)
throws InvalidMessageException, IOException
{
try {
byte[][] parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE);
Mac mac = Mac.getInstance("HmacSHA256");
mac.init(new SecretKeySpec(parts[1], "HmacSHA256"));
byte[][] parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE);
Mac mac = initMac(parts[1]);
if (file.length() <= BLOCK_SIZE + mac.getMacLength()) {
throw new InvalidMessageException("Message shorter than crypto overhead!");
}
if (digest == null) {
throw new InvalidMacException("Missing digest!");
}
final InputStream wrappedStream;
final boolean hasIncrementalMac = incrementalDigest != null && incrementalDigest.length > 0 && incrementalMacChunkSize > 0;
if (!hasIncrementalMac) {
try (FileInputStream macVerificationStream = new FileInputStream(file)) {
verifyMac(macVerificationStream, file.length(), mac, digest);
}
wrappedStream = new FileInputStream(file);
} else {
wrappedStream = new IncrementalMacInputStream(
new IncrementalMacAdditionalValidationsInputStream(
new FileInputStream(file),
file.length(),
mac,
digest
),
parts[1],
ChunkSizeChoice.everyNthByte(incrementalMacChunkSize),
incrementalDigest);
}
InputStream inputStream = new AttachmentCipherInputStream(wrappedStream, parts[0], file.length() - BLOCK_SIZE - mac.getMacLength());
if (plaintextLength != 0) {
inputStream = new ContentLengthInputStream(inputStream, plaintextLength);
}
return inputStream;
} catch (NoSuchAlgorithmException | InvalidKeyException e) {
throw new AssertionError(e);
} catch (InvalidMacException e) {
throw new InvalidMessageException(e);
if (file.length() <= BLOCK_SIZE + mac.getMacLength()) {
throw new InvalidMessageException("Message shorter than crypto overhead!");
}
if (digest == null) {
throw new InvalidMessageException("Missing digest!");
}
final InputStream wrappedStream;
final boolean hasIncrementalMac = incrementalDigest != null && incrementalDigest.length > 0 && incrementalMacChunkSize > 0;
if (!hasIncrementalMac) {
try (FileInputStream macVerificationStream = new FileInputStream(file)) {
verifyMac(macVerificationStream, file.length(), mac, digest);
}
wrappedStream = new FileInputStream(file);
} else {
wrappedStream = new IncrementalMacInputStream(
new IncrementalMacAdditionalValidationsInputStream(
new FileInputStream(file),
file.length(),
mac,
digest
),
parts[1],
ChunkSizeChoice.everyNthByte(incrementalMacChunkSize),
incrementalDigest);
}
InputStream inputStream = new AttachmentCipherInputStream(wrappedStream, parts[0], file.length() - BLOCK_SIZE - mac.getMacLength());
if (plaintextLength != 0) {
inputStream = new ContentLengthInputStream(inputStream, plaintextLength);
}
return inputStream;
}
/**
@@ -117,55 +109,41 @@ public class AttachmentCipherInputStream extends FilterInputStream {
public static InputStream createForArchivedMedia(BackupKey.MediaKeyMaterial archivedMediaKeyMaterial, File file, long originalCipherTextLength)
throws InvalidMessageException, IOException
{
try {
Mac mac = Mac.getInstance("HmacSHA256");
mac.init(new SecretKeySpec(archivedMediaKeyMaterial.getMacKey(), "HmacSHA256"));
Mac mac = initMac(archivedMediaKeyMaterial.getMacKey());
if (file.length() <= BLOCK_SIZE + mac.getMacLength()) {
throw new InvalidMessageException("Message shorter than crypto overhead!");
}
try (FileInputStream macVerificationStream = new FileInputStream(file)) {
verifyMac(macVerificationStream, file.length(), mac, null);
}
InputStream inputStream = new AttachmentCipherInputStream(new FileInputStream(file), archivedMediaKeyMaterial.getCipherKey(), file.length() - BLOCK_SIZE - mac.getMacLength());
if (originalCipherTextLength != 0) {
inputStream = new ContentLengthInputStream(inputStream, originalCipherTextLength);
}
return inputStream;
} catch (NoSuchAlgorithmException | InvalidKeyException e) {
throw new AssertionError(e);
} catch (InvalidMacException e) {
throw new InvalidMessageException(e);
if (file.length() <= BLOCK_SIZE + mac.getMacLength()) {
throw new InvalidMessageException("Message shorter than crypto overhead!");
}
try (FileInputStream macVerificationStream = new FileInputStream(file)) {
verifyMac(macVerificationStream, file.length(), mac, null);
}
InputStream inputStream = new AttachmentCipherInputStream(new FileInputStream(file), archivedMediaKeyMaterial.getCipherKey(), file.length() - BLOCK_SIZE - mac.getMacLength());
if (originalCipherTextLength != 0) {
inputStream = new ContentLengthInputStream(inputStream, originalCipherTextLength);
}
return inputStream;
}
public static InputStream createForStickerData(byte[] data, byte[] packKey)
throws InvalidMessageException, IOException
{
try {
byte[] combinedKeyMaterial = HKDF.deriveSecrets(packKey, "Sticker Pack".getBytes(), 64);
byte[][] parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE);
Mac mac = Mac.getInstance("HmacSHA256");
mac.init(new SecretKeySpec(parts[1], "HmacSHA256"));
byte[] combinedKeyMaterial = HKDF.deriveSecrets(packKey, "Sticker Pack".getBytes(), 64);
byte[][] parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE);
Mac mac = initMac(parts[1]);
if (data.length <= BLOCK_SIZE + mac.getMacLength()) {
throw new InvalidMessageException("Message shorter than crypto overhead!");
}
try (InputStream inputStream = new ByteArrayInputStream(data)) {
verifyMac(inputStream, data.length, mac, null);
}
return new AttachmentCipherInputStream(new ByteArrayInputStream(data), parts[0], data.length - BLOCK_SIZE - mac.getMacLength());
} catch (NoSuchAlgorithmException | InvalidKeyException e) {
throw new AssertionError(e);
} catch (InvalidMacException e) {
throw new InvalidMessageException(e);
if (data.length <= BLOCK_SIZE + mac.getMacLength()) {
throw new InvalidMessageException("Message shorter than crypto overhead!");
}
try (InputStream inputStream = new ByteArrayInputStream(data)) {
verifyMac(inputStream, data.length, mac, null);
}
return new AttachmentCipherInputStream(new ByteArrayInputStream(data), parts[0], data.length - BLOCK_SIZE - mac.getMacLength());
}
private AttachmentCipherInputStream(InputStream inputStream, byte[] cipherKey, long totalDataSize)
@@ -297,8 +275,18 @@ public class AttachmentCipherInputStream extends FilterInputStream {
}
}
private static Mac initMac(byte[] key) {
try {
Mac mac = Mac.getInstance("HmacSHA256");
mac.init(new SecretKeySpec(key, "HmacSHA256"));
return mac;
} catch (NoSuchAlgorithmException | InvalidKeyException e) {
throw new AssertionError(e);
}
}
private static void verifyMac(@Nonnull InputStream inputStream, long length, @Nonnull Mac mac, @Nullable byte[] theirDigest)
throws InvalidMacException
throws InvalidMessageException
{
try {
MessageDigest digest = MessageDigest.getInstance("SHA256");
@@ -317,17 +305,17 @@ public class AttachmentCipherInputStream extends FilterInputStream {
Util.readFully(inputStream, theirMac);
if (!MessageDigest.isEqual(ourMac, theirMac)) {
throw new InvalidMacException("MAC doesn't match!");
throw new InvalidMessageException("MAC doesn't match!");
}
byte[] ourDigest = digest.digest(theirMac);
if (theirDigest != null && !MessageDigest.isEqual(ourDigest, theirDigest)) {
throw new InvalidMacException("Digest doesn't match!");
throw new InvalidMessageException("Digest doesn't match!");
}
} catch (IOException | ArithmeticException e1) {
throw new InvalidMacException(e1);
throw new InvalidMessageException(e1);
} catch (NoSuchAlgorithmException e) {
throw new AssertionError(e);
}

View File

@@ -7,7 +7,6 @@ import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice;
import org.signal.libsignal.protocol.incrementalmac.InvalidMacException;
import org.signal.libsignal.protocol.kdf.HKDFv3;
import org.signal.libsignal.protocol.util.ByteUtil;
import org.whispersystems.signalservice.api.backup.BackupKey;
import org.whispersystems.signalservice.internal.crypto.PaddingInputStream;
import org.whispersystems.signalservice.internal.push.http.AttachmentCipherOutputStreamFactory;
@@ -16,7 +15,6 @@ import org.whispersystems.signalservice.internal.util.Util;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;