Refactor AttachmentCipherInputStream.

This commit is contained in:
Greyson Parrelli
2025-06-18 10:35:45 -04:00
committed by Michelle Tang
parent 9798f5cc7c
commit 4f6a5de227
10 changed files with 572 additions and 403 deletions

View File

@@ -199,7 +199,7 @@ class AttachmentTableTest {
val attachmentId = SignalDatabase.attachments.insertAttachmentsForMessage(mmsId, listOf(createAttachmentPointer(key, badlyPaddedDigest, plaintext.size)), emptyList()).values.first() val attachmentId = SignalDatabase.attachments.insertAttachmentsForMessage(mmsId, listOf(createAttachmentPointer(key, badlyPaddedDigest, plaintext.size)), emptyList()).values.first()
// Give data to attachment table // Give data to attachment table
val cipherInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintext.size.toLong(), key, badlyPaddedDigest, null, 4, false) val cipherInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintext.size.toLong(), key, badlyPaddedDigest, null, 4)
SignalDatabase.attachments.finalizeAttachmentAfterDownload(mmsId, attachmentId, cipherInputStream, iv) SignalDatabase.attachments.finalizeAttachmentAfterDownload(mmsId, attachmentId, cipherInputStream, iv)
// Verify the digest has been updated to the properly padded one // Verify the digest has been updated to the properly padded one
@@ -230,7 +230,7 @@ class AttachmentTableTest {
val attachmentId = SignalDatabase.attachments.insertAttachmentsForMessage(mmsId, listOf(createAttachmentPointer(key, digest, plaintext.size)), emptyList()).values.first() val attachmentId = SignalDatabase.attachments.insertAttachmentsForMessage(mmsId, listOf(createAttachmentPointer(key, digest, plaintext.size)), emptyList()).values.first()
// Give data to attachment table // Give data to attachment table
val cipherInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintext.size.toLong(), key, digest, null, 4, false) val cipherInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintext.size.toLong(), key, digest, null, 4)
SignalDatabase.attachments.finalizeAttachmentAfterDownload(mmsId, attachmentId, cipherInputStream, iv) SignalDatabase.attachments.finalizeAttachmentAfterDownload(mmsId, attachmentId, cipherInputStream, iv)
// Verify the digest hasn't changed // Verify the digest hasn't changed

View File

@@ -252,7 +252,6 @@ class RestoreAttachmentJob private constructor(
pointer, pointer,
attachmentFile, attachmentFile,
maxReceiveSize, maxReceiveSize,
false,
progressListener progressListener
) )
} else { } else {

View File

@@ -132,14 +132,13 @@ class RestoreAttachmentThumbnailJob private constructor(
Log.i(TAG, "Downloading thumbnail for $attachmentId") Log.i(TAG, "Downloading thumbnail for $attachmentId")
val downloadResult = AppDependencies.signalServiceMessageReceiver val downloadResult = AppDependencies.signalServiceMessageReceiver
.retrieveArchivedAttachment( .retrieveArchivedThumbnail(
SignalStore.backup.mediaRootBackupKey.deriveMediaSecrets(attachment.requireThumbnailMediaName()), SignalStore.backup.mediaRootBackupKey.deriveMediaSecrets(attachment.requireThumbnailMediaName()),
cdnCredentials, cdnCredentials,
thumbnailTransferFile, thumbnailTransferFile,
pointer, pointer,
thumbnailFile, thumbnailFile,
maxThumbnailSize, maxThumbnailSize,
true,
progressListener progressListener
) )

View File

@@ -149,7 +149,15 @@ class RestoreLocalAttachmentJob private constructor(
try { try {
val iv = ByteArray(16) val iv = ByteArray(16)
streamSupplier.openStream().use { StreamUtil.readFully(it, iv) } streamSupplier.openStream().use { StreamUtil.readFully(it, iv) }
AttachmentCipherInputStream.createForAttachment(streamSupplier, size, attachment.size, combinedKey, attachment.remoteDigest, null, 0, false).use { input -> AttachmentCipherInputStream.createForAttachment(
streamSupplier = streamSupplier,
streamLength = size,
plaintextLength = attachment.size,
combinedKeyMaterial = combinedKey,
digest = attachment.remoteDigest,
incrementalDigest = null,
incrementalMacChunkSize = 0
).use { input ->
SignalDatabase.attachments.finalizeAttachmentAfterDownload(attachment.mmsId, attachment.attachmentId, input, iv) SignalDatabase.attachments.finalizeAttachmentAfterDownload(attachment.mmsId, attachment.attachmentId, input, iv)
} }
} catch (e: InvalidMessageException) { } catch (e: InvalidMessageException) {

View File

@@ -11,6 +11,7 @@ import org.signal.libsignal.protocol.InvalidMessageException;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream; import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream;
import java.io.File; import java.io.File;
import java.io.FileInputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.util.Optional; import java.util.Optional;
@@ -37,7 +38,12 @@ class AttachmentStreamLocalUriFetcher implements DataFetcher<InputStream> {
public void loadData(@NonNull Priority priority, @NonNull DataCallback<? super InputStream> callback) { public void loadData(@NonNull Priority priority, @NonNull DataCallback<? super InputStream> callback) {
try { try {
if (!digest.isPresent()) throw new InvalidMessageException("No attachment digest!"); if (!digest.isPresent()) throw new InvalidMessageException("No attachment digest!");
is = AttachmentCipherInputStream.createForAttachment(attachment, plaintextLength, key, digest.get(), null, 0); is = AttachmentCipherInputStream.createForAttachment(attachment,
plaintextLength,
key,
digest.get(),
null,
0);
callback.onDataReady(is); callback.onDataReady(is);
} catch (IOException | InvalidMessageException e) { } catch (IOException | InvalidMessageException e) {
callback.onLoadFailed(e); callback.onLoadFailed(e);

View File

@@ -71,7 +71,8 @@ class PartDataSource implements DataSource {
final boolean hasData = attachment.hasData; final boolean hasData = attachment.hasData;
if (inProgress && !hasData && hasIncrementalDigest && attachmentKey != null) { if (inProgress && !hasData && hasIncrementalDigest && attachmentKey != null) {
final byte[] decode = Base64.decode(attachmentKey); final byte[] decodedKey = Base64.decode(attachmentKey);
if (attachment.transferState == AttachmentTable.TRANSFER_RESTORE_IN_PROGRESS && attachment.archiveTransferState == AttachmentTable.ArchiveTransferState.FINISHED) { if (attachment.transferState == AttachmentTable.TRANSFER_RESTORE_IN_PROGRESS && attachment.archiveTransferState == AttachmentTable.ArchiveTransferState.FINISHED) {
final File archiveFile = attachmentDatabase.getOrCreateArchiveTransferFile(attachment.attachmentId); final File archiveFile = attachmentDatabase.getOrCreateArchiveTransferFile(attachment.attachmentId);
try { try {
@@ -81,7 +82,11 @@ class PartDataSource implements DataSource {
MediaRootBackupKey.MediaKeyMaterial mediaKeyMaterial = SignalStore.backup().getMediaRootBackupKey().deriveMediaSecretsFromMediaId(mediaId); MediaRootBackupKey.MediaKeyMaterial mediaKeyMaterial = SignalStore.backup().getMediaRootBackupKey().deriveMediaSecretsFromMediaId(mediaId);
long originalCipherLength = AttachmentCipherStreamUtil.getCiphertextLength(PaddingInputStream.getPaddedSize(attachment.size)); long originalCipherLength = AttachmentCipherStreamUtil.getCiphertextLength(PaddingInputStream.getPaddedSize(attachment.size));
this.inputStream = AttachmentCipherInputStream.createStreamingForArchivedAttachment(mediaKeyMaterial, archiveFile, originalCipherLength, attachment.size, attachment.remoteDigest, decode, attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize); if (attachment.remoteDigest == null) {
throw new InvalidMessageException("Missing digest!");
}
this.inputStream = AttachmentCipherInputStream.createForArchivedMediaOuterAndInnerLayers(mediaKeyMaterial, archiveFile, originalCipherLength, attachment.size, decodedKey, attachment.remoteDigest, attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize);
} catch (InvalidMessageException e) { } catch (InvalidMessageException e) {
throw new IOException("Error decrypting attachment stream!", e); throw new IOException("Error decrypting attachment stream!", e);
} }
@@ -95,7 +100,7 @@ class PartDataSource implements DataSource {
throw new InvalidMessageException("Missing digest!"); throw new InvalidMessageException("Missing digest!");
} }
this.inputStream = AttachmentCipherInputStream.createForAttachment(streamSupplier, streamLength, attachment.size, decode, attachment.remoteDigest, attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize, false); this.inputStream = AttachmentCipherInputStream.createForAttachment(streamSupplier, streamLength, attachment.size, decodedKey, attachment.remoteDigest, attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize);
} catch (InvalidMessageException e) { } catch (InvalidMessageException e) {
throw new IOException("Error decrypting attachment stream!", e); throw new IOException("Error decrypting attachment stream!", e);
} }

View File

@@ -7,9 +7,6 @@
package org.whispersystems.signalservice.api; package org.whispersystems.signalservice.api;
import org.signal.core.util.StreamUtil; import org.signal.core.util.StreamUtil;
import org.signal.core.util.concurrent.FutureTransformers;
import org.signal.core.util.concurrent.ListenableFuture;
import org.signal.core.util.concurrent.SettableFuture;
import org.signal.core.util.stream.LimitedInputStream; import org.signal.core.util.stream.LimitedInputStream;
import org.signal.libsignal.protocol.InvalidMessageException; import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.zkgroup.profiles.ProfileKey; import org.signal.libsignal.zkgroup.profiles.ProfileKey;
@@ -18,24 +15,15 @@ import org.whispersystems.signalservice.api.backup.MediaRootBackupKey;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream; import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherStreamUtil; import org.whispersystems.signalservice.api.crypto.AttachmentCipherStreamUtil;
import org.whispersystems.signalservice.api.crypto.ProfileCipherInputStream; import org.whispersystems.signalservice.api.crypto.ProfileCipherInputStream;
import org.whispersystems.signalservice.api.crypto.SealedSenderAccess;
import org.whispersystems.signalservice.api.messages.SignalServiceAttachment.ProgressListener; import org.whispersystems.signalservice.api.messages.SignalServiceAttachment.ProgressListener;
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer; import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer;
import org.whispersystems.signalservice.api.messages.SignalServiceDataMessage; import org.whispersystems.signalservice.api.messages.SignalServiceDataMessage;
import org.whispersystems.signalservice.api.messages.SignalServiceStickerManifest; import org.whispersystems.signalservice.api.messages.SignalServiceStickerManifest;
import org.whispersystems.signalservice.api.profiles.ProfileAndCredential;
import org.whispersystems.signalservice.api.profiles.SignalServiceProfile;
import org.whispersystems.signalservice.api.push.ServiceId.ACI;
import org.whispersystems.signalservice.api.push.SignalServiceAddress;
import org.whispersystems.signalservice.api.push.exceptions.MissingConfigurationException; import org.whispersystems.signalservice.api.push.exceptions.MissingConfigurationException;
import org.whispersystems.signalservice.internal.ServiceResponse;
import org.whispersystems.signalservice.internal.crypto.PaddingInputStream; import org.whispersystems.signalservice.internal.crypto.PaddingInputStream;
import org.whispersystems.signalservice.internal.push.IdentityCheckRequest;
import org.whispersystems.signalservice.internal.push.IdentityCheckResponse;
import org.whispersystems.signalservice.internal.push.PushServiceSocket; import org.whispersystems.signalservice.internal.push.PushServiceSocket;
import org.whispersystems.signalservice.internal.sticker.Pack; import org.whispersystems.signalservice.internal.sticker.Pack;
import org.whispersystems.signalservice.internal.util.Util; import org.whispersystems.signalservice.internal.util.Util;
import org.whispersystems.signalservice.internal.websocket.ResponseMapper;
import java.io.File; import java.io.File;
import java.io.FileInputStream; import java.io.FileInputStream;
@@ -46,15 +34,11 @@ import java.time.ZonedDateTime;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import io.reactivex.rxjava3.core.Single;
/** /**
* The primary interface for receiving Signal Service messages. * The primary interface for receiving Signal Service messages.
* *
@@ -118,6 +102,7 @@ public class SignalServiceMessageReceiver {
public AttachmentDownloadResult retrieveAttachment(SignalServiceAttachmentPointer pointer, File destination, long maxSizeBytes, ProgressListener listener) public AttachmentDownloadResult retrieveAttachment(SignalServiceAttachmentPointer pointer, File destination, long maxSizeBytes, ProgressListener listener)
throws IOException, InvalidMessageException, MissingConfigurationException { throws IOException, InvalidMessageException, MissingConfigurationException {
if (!pointer.getDigest().isPresent()) throw new InvalidMessageException("No attachment digest!"); if (!pointer.getDigest().isPresent()) throw new InvalidMessageException("No attachment digest!");
if (pointer.getKey() == null) throw new InvalidMessageException("No key!");
socket.retrieveAttachment(pointer.getCdnNumber(), Collections.emptyMap(), pointer.getRemoteId(), destination, maxSizeBytes, listener); socket.retrieveAttachment(pointer.getCdnNumber(), Collections.emptyMap(), pointer.getRemoteId(), destination, maxSizeBytes, listener);
@@ -127,7 +112,14 @@ public class SignalServiceMessageReceiver {
} }
return new AttachmentDownloadResult( return new AttachmentDownloadResult(
AttachmentCipherInputStream.createForAttachment(destination, pointer.getSize().orElse(0), pointer.getKey(), pointer.getDigest().get(), null, 0), AttachmentCipherInputStream.createForAttachment(
destination,
pointer.getSize().orElse(0),
pointer.getKey(),
pointer.getDigest().get(),
null,
0
),
iv iv
); );
} }
@@ -150,14 +142,17 @@ public class SignalServiceMessageReceiver {
@Nonnull SignalServiceAttachmentPointer pointer, @Nonnull SignalServiceAttachmentPointer pointer,
@Nonnull File attachmentDestination, @Nonnull File attachmentDestination,
long maxSizeBytes, long maxSizeBytes,
boolean ignoreDigest,
@Nullable ProgressListener listener) @Nullable ProgressListener listener)
throws IOException, InvalidMessageException, MissingConfigurationException throws IOException, InvalidMessageException, MissingConfigurationException
{ {
if (!ignoreDigest && pointer.getDigest().isEmpty()) { if (pointer.getDigest().isEmpty()) {
throw new InvalidMessageException("No attachment digest!"); throw new InvalidMessageException("No attachment digest!");
} }
if (pointer.getKey() == null) {
throw new InvalidMessageException("No key!");
}
socket.retrieveAttachment(pointer.getCdnNumber(), readCredentialHeaders, pointer.getRemoteId(), archiveDestination, maxSizeBytes, listener); socket.retrieveAttachment(pointer.getCdnNumber(), readCredentialHeaders, pointer.getRemoteId(), archiveDestination, maxSizeBytes, listener);
long originalCipherLength = pointer.getSize() long originalCipherLength = pointer.getSize()
@@ -166,7 +161,7 @@ public class SignalServiceMessageReceiver {
.orElse(0L); .orElse(0L);
// There's two layers of encryption -- one from the backup, and one from the attachment. This only strips the outermost backup encryption layer. // There's two layers of encryption -- one from the backup, and one from the attachment. This only strips the outermost backup encryption layer.
try (InputStream backupDecrypted = AttachmentCipherInputStream.createForArchivedMedia(archivedMediaKeyMaterial, archiveDestination, originalCipherLength)) { try (InputStream backupDecrypted = AttachmentCipherInputStream.createForArchivedMediaOuterLayer(archivedMediaKeyMaterial, archiveDestination, originalCipherLength)) {
try (FileOutputStream fos = new FileOutputStream(attachmentDestination)) { try (FileOutputStream fos = new FileOutputStream(attachmentDestination)) {
// TODO [backup] I don't think we should be doing the full copy here. This is basically doing the entire download inline in this single line. // TODO [backup] I don't think we should be doing the full copy here. This is basically doing the entire download inline in this single line.
StreamUtil.copy(backupDecrypted, fos); StreamUtil.copy(backupDecrypted, fos);
@@ -182,10 +177,63 @@ public class SignalServiceMessageReceiver {
attachmentDestination, attachmentDestination,
pointer.getSize().orElse(0), pointer.getSize().orElse(0),
pointer.getKey(), pointer.getKey(),
ignoreDigest ? null : pointer.getDigest().get(), pointer.getDigest().get(),
null, null,
0, 0
ignoreDigest );
return new AttachmentDownloadResult(dataStream, iv);
}
/**
* Retrieves an archived media attachment.
*
* @param archivedMediaKeyMaterial Decryption key material for decrypting outer layer of archived media.
* @param readCredentialHeaders Headers to pass to the backup CDN to authorize the download
* @param archiveDestination The download destination for archived attachment. If this file exists, download will resume.
* @param pointer The {@link SignalServiceAttachmentPointer} received in a {@link SignalServiceDataMessage}.
* @param attachmentDestination The download destination for this attachment. If this file exists, it is assumed that this is previously-downloaded content that can be resumed.
* @param listener An optional listener (may be null) to receive callbacks on download progress.
*
* @return An InputStream that streams the plaintext attachment contents.
*/
public AttachmentDownloadResult retrieveArchivedThumbnail(@Nonnull MediaRootBackupKey.MediaKeyMaterial archivedMediaKeyMaterial,
@Nonnull Map<String, String> readCredentialHeaders,
@Nonnull File archiveDestination,
@Nonnull SignalServiceAttachmentPointer pointer,
@Nonnull File attachmentDestination,
long maxSizeBytes,
@Nullable ProgressListener listener)
throws IOException, InvalidMessageException, MissingConfigurationException
{
if (pointer.getKey() == null) {
throw new InvalidMessageException("No key!");
}
socket.retrieveAttachment(pointer.getCdnNumber(), readCredentialHeaders, pointer.getRemoteId(), archiveDestination, maxSizeBytes, listener);
long originalCipherLength = pointer.getSize()
.filter(s -> s > 0)
.map(s -> AttachmentCipherStreamUtil.getCiphertextLength(PaddingInputStream.getPaddedSize(s)))
.orElse(0L);
// There's two layers of encryption -- one from the backup, and one from the attachment. This only strips the outermost backup encryption layer.
try (InputStream backupDecrypted = AttachmentCipherInputStream.createForArchivedMediaOuterLayer(archivedMediaKeyMaterial, archiveDestination, originalCipherLength)) {
try (FileOutputStream fos = new FileOutputStream(attachmentDestination)) {
// TODO [backup] I don't think we should be doing the full copy here. This is basically doing the entire download inline in this single line.
StreamUtil.copy(backupDecrypted, fos);
}
}
byte[] iv = new byte[16];
try (InputStream tempStream = new FileInputStream(attachmentDestination)) {
StreamUtil.readFully(tempStream, iv);
}
LimitedInputStream dataStream = AttachmentCipherInputStream.createForArchiveThumbnailInnerLayer(
attachmentDestination,
pointer.getSize().orElse(0),
pointer.getKey()
); );
return new AttachmentDownloadResult(dataStream, iv); return new AttachmentDownloadResult(dataStream, iv);

View File

@@ -5,8 +5,8 @@
*/ */
package org.whispersystems.signalservice.api.crypto package org.whispersystems.signalservice.api.crypto
import org.signal.core.util.readNBytesOrThrow
import org.signal.core.util.stream.LimitedInputStream import org.signal.core.util.stream.LimitedInputStream
import org.signal.core.util.stream.LimitedInputStream.Companion.withoutLimits
import org.signal.libsignal.protocol.InvalidMessageException import org.signal.libsignal.protocol.InvalidMessageException
import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice
import org.signal.libsignal.protocol.incrementalmac.IncrementalMacInputStream import org.signal.libsignal.protocol.incrementalmac.IncrementalMacInputStream
@@ -16,185 +16,316 @@ import org.whispersystems.signalservice.internal.util.Util
import java.io.ByteArrayInputStream import java.io.ByteArrayInputStream
import java.io.File import java.io.File
import java.io.FileInputStream import java.io.FileInputStream
import java.io.FilterInputStream
import java.io.IOException import java.io.IOException
import java.io.InputStream import java.io.InputStream
import java.security.InvalidKeyException import java.security.InvalidKeyException
import java.security.MessageDigest import java.security.MessageDigest
import java.security.NoSuchAlgorithmException import java.security.NoSuchAlgorithmException
import javax.annotation.Nonnull import javax.annotation.Nonnull
import javax.crypto.BadPaddingException
import javax.crypto.Cipher import javax.crypto.Cipher
import javax.crypto.IllegalBlockSizeException
import javax.crypto.Mac import javax.crypto.Mac
import javax.crypto.ShortBufferException
import javax.crypto.spec.IvParameterSpec import javax.crypto.spec.IvParameterSpec
import javax.crypto.spec.SecretKeySpec import javax.crypto.spec.SecretKeySpec
import kotlin.math.min import kotlin.math.min
/** /**
* Class for streaming an encrypted push attachment off disk. * Decrypts an attachment stream that has been encrypted with AES/CBC/PKCS5Padding.
* *
* @author Moxie Marlinspike * It assumes that the first 16 bytes of the stream are the IV, and that the rest of the stream is encrypted data.
*/ */
class AttachmentCipherInputStream private constructor( object AttachmentCipherInputStream {
inputStream: InputStream,
aesKey: ByteArray,
private val totalDataSize: Long
) : FilterInputStream(inputStream) {
private val cipher: Cipher private const val BLOCK_SIZE = 16
private const val CIPHER_KEY_SIZE = 32
private const val MAC_KEY_SIZE = 32
private var done = false /**
private var totalRead: Long = 0 * Creates a stream to decrypt a typical attachment via a [File].
private var overflowBuffer: ByteArray? = null *
* @param incrementalDigest If null, incremental mac validation is disabled.
init { * @param incrementalMacChunkSize If 0, incremental mac validation is disabled.
val iv = ByteArray(BLOCK_SIZE) */
readFullyWithoutDecrypting(iv) @JvmStatic
@Throws(InvalidMessageException::class, IOException::class)
this.cipher = Cipher.getInstance("AES/CBC/PKCS5Padding") fun createForAttachment(
cipher.init(Cipher.DECRYPT_MODE, SecretKeySpec(aesKey, "AES"), IvParameterSpec(iv)) file: File,
plaintextLength: Long,
combinedKeyMaterial: ByteArray,
digest: ByteArray,
incrementalDigest: ByteArray?,
incrementalMacChunkSize: Int
): LimitedInputStream {
return create(
streamSupplier = { FileInputStream(file) },
streamLength = file.length(),
plaintextLength = plaintextLength,
combinedKeyMaterial = combinedKeyMaterial,
digest = digest,
incrementalDigest = incrementalDigest,
incrementalMacChunkSize = incrementalMacChunkSize,
ignoreDigest = false
)
} }
@Throws(IOException::class) /**
override fun read(): Int { * Creates a stream to decrypt a typical attachment via a [StreamSupplier].
val buffer = ByteArray(1) *
var read: Int = read(buffer) * @param incrementalDigest If null, incremental mac validation is disabled.
while (read == 0) { * @param incrementalMacChunkSize If 0, incremental mac validation is disabled.
read = read(buffer) */
@JvmStatic
@Throws(InvalidMessageException::class, IOException::class)
fun createForAttachment(
streamSupplier: StreamSupplier,
streamLength: Long,
plaintextLength: Long,
combinedKeyMaterial: ByteArray,
digest: ByteArray,
incrementalDigest: ByteArray?,
incrementalMacChunkSize: Int
): LimitedInputStream {
return create(
streamSupplier = streamSupplier,
streamLength = streamLength,
plaintextLength = plaintextLength,
combinedKeyMaterial = combinedKeyMaterial,
digest = digest,
incrementalDigest = incrementalDigest,
incrementalMacChunkSize = incrementalMacChunkSize,
ignoreDigest = false
)
}
/**
* After removing the server layer of encryption using [createForArchivedMediaOuterLayer], use this to decrypt the inner layer of the attachment.
* Thumbnails have a special path because we don't do any additional digest/hash validation on them.
*/
@JvmStatic
@Throws(InvalidMessageException::class, IOException::class)
fun createForArchiveThumbnailInnerLayer(
file: File,
plaintextLength: Long,
combinedKeyMaterial: ByteArray
): LimitedInputStream {
return create(
streamSupplier = { FileInputStream(file) },
streamLength = file.length(),
plaintextLength = plaintextLength,
combinedKeyMaterial = combinedKeyMaterial,
digest = null,
incrementalDigest = null,
incrementalMacChunkSize = 0,
ignoreDigest = true
)
}
/**
* When you archive an attachment, you give the server an encrypted attachment, and the server wraps it in *another* layer of encryption.
* This will return a stream that unwraps the server's layer of encryption, giving you a stream that contains a "normally-encrypted" attachment.
*
* Because we're validating the encryptedDigest/plaintextHash of the inner layer, there's no additional out-of-band validation of this outer layer.
*/
@JvmStatic
@Throws(InvalidMessageException::class, IOException::class)
fun createForArchivedMediaOuterLayer(archivedMediaKeyMaterial: MediaKeyMaterial, file: File, originalCipherTextLength: Long): LimitedInputStream {
val mac = initMac(archivedMediaKeyMaterial.macKey)
if (file.length() <= BLOCK_SIZE + mac.macLength) {
throw InvalidMessageException("Message shorter than crypto overhead!")
} }
if (read == -1) { FileInputStream(file).use { macVerificationStream ->
return read verifyMac(macVerificationStream, file.length(), mac, null)
} }
return buffer[0].toInt() and 0xFF val encryptedStream = FileInputStream(file)
val encryptedStreamExcludingMac = LimitedInputStream(encryptedStream, file.length() - mac.macLength)
val cipher = createCipher(encryptedStreamExcludingMac, archivedMediaKeyMaterial.aesKey)
val inputStream: InputStream = BetterCipherInputStream(encryptedStreamExcludingMac, cipher)
return LimitedInputStream(inputStream, originalCipherTextLength)
} }
@Throws(IOException::class) /**
override fun read(@Nonnull buffer: ByteArray): Int { * When you archive an attachment, you give the server an encrypted attachment, and the server wraps it in *another* layer of encryption.
return read(buffer, 0, buffer.size) *
} * This creates a stream decrypt both the inner and outer layers of an archived attachment at the same time by basically double-decrypting it.
*
* @param incrementalDigest If null, incremental mac validation is disabled.
* @param incrementalMacChunkSize If 0, incremental mac validation is disabled.
*/
@JvmStatic
@Throws(InvalidMessageException::class, IOException::class)
fun createForArchivedMediaOuterAndInnerLayers(
archivedMediaKeyMaterial: MediaKeyMaterial,
file: File,
originalCipherTextLength: Long,
plaintextLength: Long,
combinedKeyMaterial: ByteArray,
digest: ByteArray,
incrementalDigest: ByteArray?,
incrementalMacChunkSize: Int
): LimitedInputStream {
val keyMaterial = CombinedKeyMaterial.from(combinedKeyMaterial)
val mac = initMac(keyMaterial.macKey)
@Throws(IOException::class) if (originalCipherTextLength <= BLOCK_SIZE + mac.macLength) {
override fun read(@Nonnull buffer: ByteArray, offset: Int, length: Int): Int { throw InvalidMessageException("Message shorter than crypto overhead!")
return if (totalRead != totalDataSize) {
readIncremental(buffer, offset, length)
} else if (!done) {
readFinal(buffer, offset, length)
} else {
-1
}
}
override fun markSupported(): Boolean = false
@Throws(IOException::class)
override fun skip(byteCount: Long): Long {
var skipped = 0L
while (skipped < byteCount) {
val remaining = byteCount - skipped
val buffer = ByteArray(min(4096, remaining.toInt()))
val read = read(buffer)
skipped += read.toLong()
} }
return skipped return create(
streamSupplier = { createForArchivedMediaOuterLayer(archivedMediaKeyMaterial, file, originalCipherTextLength) },
streamLength = originalCipherTextLength,
plaintextLength = plaintextLength,
combinedKeyMaterial = combinedKeyMaterial,
digest = digest,
incrementalDigest = incrementalDigest,
incrementalMacChunkSize = incrementalMacChunkSize,
ignoreDigest = false
)
} }
@Throws(IOException::class) /**
private fun readIncremental(outputBuffer: ByteArray, originalOffset: Int, originalLength: Int): Int { * Creates a stream to decrypt sticker data. Stickers have a special path because the key material is derived from the pack key.
var offset = originalOffset */
var length = originalLength @JvmStatic
var readLength = 0 @Throws(InvalidMessageException::class, IOException::class)
fun createForStickerData(data: ByteArray, packKey: ByteArray): InputStream {
val keyMaterial = CombinedKeyMaterial.from(HKDF.deriveSecrets(packKey, "Sticker Pack".toByteArray(), 64))
val mac = initMac(keyMaterial.macKey)
overflowBuffer?.let { overflow -> if (data.size <= BLOCK_SIZE + mac.macLength) {
if (overflow.size > length) { throw InvalidMessageException("Message shorter than crypto overhead!")
overflow.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = length) }
overflowBuffer = overflow.copyOfRange(fromIndex = length, toIndex = overflow.size)
return length ByteArrayInputStream(data).use { inputStream ->
} else if (overflow.size == length) { verifyMac(inputStream, data.size.toLong(), mac, null)
overflow.copyInto(destination = outputBuffer, destinationOffset = offset) }
overflowBuffer = null
return length val encryptedStream = ByteArrayInputStream(data)
} else { val encryptedStreamExcludingMac = LimitedInputStream(encryptedStream, data.size.toLong() - mac.macLength)
overflow.copyInto(destination = outputBuffer, destinationOffset = offset) val cipher = createCipher(encryptedStreamExcludingMac, keyMaterial.aesKey)
readLength += overflow.size
offset += readLength return BetterCipherInputStream(encryptedStreamExcludingMac, cipher)
length -= readLength }
overflowBuffer = null
@JvmStatic
@Throws(InvalidMessageException::class, IOException::class)
private fun create(
streamSupplier: StreamSupplier,
streamLength: Long,
plaintextLength: Long,
combinedKeyMaterial: ByteArray,
digest: ByteArray?,
incrementalDigest: ByteArray?,
incrementalMacChunkSize: Int,
ignoreDigest: Boolean
): LimitedInputStream {
val keyMaterial = CombinedKeyMaterial.from(combinedKeyMaterial)
val mac = initMac(keyMaterial.macKey)
if (streamLength <= BLOCK_SIZE + mac.macLength) {
throw InvalidMessageException("Message shorter than crypto overhead! length: $streamLength")
}
if (!ignoreDigest && digest == null) {
throw InvalidMessageException("Missing digest!")
}
val wrappedStream: InputStream
val hasIncrementalMac = incrementalDigest != null && incrementalDigest.isNotEmpty() && incrementalMacChunkSize > 0
if (!hasIncrementalMac) {
streamSupplier.openStream().use { macVerificationStream ->
verifyMac(macVerificationStream, streamLength, mac, digest)
} }
} wrappedStream = streamSupplier.openStream()
if (length + totalRead > totalDataSize) {
length = (totalDataSize - totalRead).toInt()
}
val ciphertextBuffer = ByteArray(length)
val ciphertextReadLength = if (ciphertextBuffer.size <= cipher.blockSize) {
ciphertextBuffer.size
} else { } else {
// Ensure we leave the final block for readFinal() if (digest == null) {
ciphertextBuffer.size - cipher.blockSize throw InvalidMessageException("Missing digest for incremental mac validation!")
} }
val ciphertextRead = super.read(ciphertextBuffer, 0, ciphertextReadLength)
totalRead += ciphertextRead.toLong()
wrappedStream = IncrementalMacInputStream(
IncrementalMacAdditionalValidationsInputStream(
wrapped = streamSupplier.openStream(),
fileLength = streamLength,
mac = mac,
theirDigest = digest
),
keyMaterial.macKey,
ChunkSizeChoice.everyNthByte(incrementalMacChunkSize),
incrementalDigest
)
}
val encryptedStreamExcludingMac = LimitedInputStream(wrappedStream, streamLength - mac.macLength)
val cipher = createCipher(encryptedStreamExcludingMac, keyMaterial.aesKey)
val decryptingStream: InputStream = BetterCipherInputStream(encryptedStreamExcludingMac, cipher)
return LimitedInputStream(decryptingStream, plaintextLength)
}
private fun createCipher(inputStream: InputStream, aesKey: ByteArray): Cipher {
val iv = inputStream.readNBytesOrThrow(BLOCK_SIZE)
return Cipher.getInstance("AES/CBC/PKCS5Padding").apply {
init(Cipher.DECRYPT_MODE, SecretKeySpec(aesKey, "AES"), IvParameterSpec(iv))
}
}
private fun initMac(key: ByteArray): Mac {
try { try {
var plaintextLength = cipher.getOutputSize(ciphertextRead) val mac = Mac.getInstance("HmacSHA256")
mac.init(SecretKeySpec(key, "HmacSHA256"))
if (plaintextLength <= length) { return mac
readLength += cipher.update(ciphertextBuffer, 0, ciphertextRead, outputBuffer, offset) } catch (e: NoSuchAlgorithmException) {
return readLength throw AssertionError(e)
} } catch (e: InvalidKeyException) {
val plaintextBuffer = ByteArray(plaintextLength)
plaintextLength = cipher.update(ciphertextBuffer, 0, ciphertextRead, plaintextBuffer, 0)
if (plaintextLength <= length) {
plaintextBuffer.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = plaintextLength)
readLength += plaintextLength
} else {
plaintextBuffer.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = length)
overflowBuffer = plaintextBuffer.copyOfRange(fromIndex = length, toIndex = plaintextLength)
readLength += length
}
return readLength
} catch (e: ShortBufferException) {
throw AssertionError(e) throw AssertionError(e)
} }
} }
@Throws(IOException::class) @Throws(InvalidMessageException::class)
private fun readFinal(buffer: ByteArray, offset: Int, length: Int): Int { private fun verifyMac(@Nonnull inputStream: InputStream, length: Long, @Nonnull mac: Mac, theirDigest: ByteArray?) {
try { try {
val internal = ByteArray(buffer.size) val digest = MessageDigest.getInstance("SHA256")
val actualLength = min(length, cipher.doFinal(internal, 0)) var remainingData = Util.toIntExact(length) - mac.macLength
internal.copyInto(destination = buffer, destinationOffset = offset, endIndex = actualLength) val buffer = ByteArray(4096)
done = true while (remainingData > 0) {
return actualLength val read = inputStream.read(buffer, 0, min(buffer.size, remainingData))
} catch (e: IllegalBlockSizeException) { mac.update(buffer, 0, read)
throw IOException(e) digest.update(buffer, 0, read)
} catch (e: BadPaddingException) { remainingData -= read
throw IOException(e) }
} catch (e: ShortBufferException) {
throw IOException(e) val ourMac = mac.doFinal()
val theirMac = ByteArray(mac.macLength)
Util.readFully(inputStream, theirMac)
if (!MessageDigest.isEqual(ourMac, theirMac)) {
throw InvalidMessageException("MAC doesn't match!")
}
val ourDigest = digest.digest(theirMac)
if (theirDigest != null && !MessageDigest.isEqual(ourDigest, theirDigest)) {
throw InvalidMessageException("Digest doesn't match!")
}
} catch (e: IOException) {
throw InvalidMessageException(e)
} catch (e: ArithmeticException) {
throw InvalidMessageException(e)
} catch (e: NoSuchAlgorithmException) {
throw AssertionError(e)
} }
} }
@Throws(IOException::class) private class CombinedKeyMaterial(val aesKey: ByteArray, val macKey: ByteArray) {
private fun readFullyWithoutDecrypting(buffer: ByteArray) { companion object {
var offset = 0 fun from(combinedKeyMaterial: ByteArray): CombinedKeyMaterial {
val parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE)
while (true) { return CombinedKeyMaterial(parts[0], parts[1])
val read = super.read(buffer, offset, buffer.size - offset)
if (read + offset < buffer.size) {
offset += read
} else {
return
} }
} }
} }
@@ -204,207 +335,4 @@ class AttachmentCipherInputStream private constructor(
@Throws(IOException::class) @Throws(IOException::class)
fun openStream(): InputStream fun openStream(): InputStream
} }
companion object {
private const val BLOCK_SIZE = 16
private const val CIPHER_KEY_SIZE = 32
private const val MAC_KEY_SIZE = 32
/**
* Passing in a null incrementalDigest and/or 0 for the chunk size at the call site disables incremental mac validation.
*
* Passing in true for ignoreDigest DOES NOT VERIFY THE DIGEST
*/
@JvmStatic
@JvmOverloads
@Throws(InvalidMessageException::class, IOException::class)
fun createForAttachment(file: File, plaintextLength: Long, combinedKeyMaterial: ByteArray?, digest: ByteArray?, incrementalDigest: ByteArray?, incrementalMacChunkSize: Int, ignoreDigest: Boolean = false): LimitedInputStream {
return createForAttachment({ FileInputStream(file) }, file.length(), plaintextLength, combinedKeyMaterial, digest, incrementalDigest, incrementalMacChunkSize, ignoreDigest)
}
/**
* Passing in a null incrementalDigest and/or 0 for the chunk size at the call site disables incremental mac validation.
*
* Passing in true for ignoreDigest DOES NOT VERIFY THE DIGEST
*/
@JvmStatic
@Throws(InvalidMessageException::class, IOException::class)
fun createForAttachment(
streamSupplier: StreamSupplier,
streamLength: Long,
plaintextLength: Long,
combinedKeyMaterial: ByteArray?,
digest: ByteArray?,
incrementalDigest: ByteArray?,
incrementalMacChunkSize: Int,
ignoreDigest: Boolean
): LimitedInputStream {
val parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE)
val mac = initMac(parts[1])
if (streamLength <= BLOCK_SIZE + mac.macLength) {
throw InvalidMessageException("Message shorter than crypto overhead! length: $streamLength")
}
if (!ignoreDigest && digest == null) {
throw InvalidMessageException("Missing digest!")
}
val wrappedStream: InputStream
val hasIncrementalMac = incrementalDigest != null && incrementalDigest.isNotEmpty() && incrementalMacChunkSize > 0
if (!hasIncrementalMac) {
streamSupplier.openStream().use { macVerificationStream ->
verifyMac(macVerificationStream, streamLength, mac, digest)
}
wrappedStream = streamSupplier.openStream()
} else {
wrappedStream = IncrementalMacInputStream(
IncrementalMacAdditionalValidationsInputStream(
streamSupplier.openStream(),
streamLength,
mac,
digest!!
),
parts[1],
ChunkSizeChoice.everyNthByte(incrementalMacChunkSize),
incrementalDigest
)
}
val inputStream: InputStream = AttachmentCipherInputStream(wrappedStream, parts[0], streamLength - BLOCK_SIZE - mac.macLength)
return LimitedInputStream(inputStream, plaintextLength)
}
/**
* Decrypt archived media to it's original attachment encrypted blob.
*/
@JvmStatic
@Throws(InvalidMessageException::class, IOException::class)
fun createForArchivedMedia(archivedMediaKeyMaterial: MediaKeyMaterial, file: File, originalCipherTextLength: Long): LimitedInputStream {
val mac = initMac(archivedMediaKeyMaterial.macKey)
if (file.length() <= BLOCK_SIZE + mac.macLength) {
throw InvalidMessageException("Message shorter than crypto overhead!")
}
FileInputStream(file).use { macVerificationStream ->
verifyMac(macVerificationStream, file.length(), mac, null)
}
val inputStream: InputStream = AttachmentCipherInputStream(FileInputStream(file), archivedMediaKeyMaterial.aesKey, file.length() - BLOCK_SIZE - mac.macLength)
return LimitedInputStream(inputStream, originalCipherTextLength)
}
@JvmStatic
@Throws(InvalidMessageException::class, IOException::class)
fun createStreamingForArchivedAttachment(
archivedMediaKeyMaterial: MediaKeyMaterial,
file: File,
originalCipherTextLength: Long,
plaintextLength: Long,
combinedKeyMaterial: ByteArray?,
digest: ByteArray,
incrementalDigest: ByteArray?,
incrementalMacChunkSize: Int
): LimitedInputStream {
val archiveStream: InputStream = createForArchivedMedia(archivedMediaKeyMaterial, file, originalCipherTextLength)
val parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE)
val mac = initMac(parts[1])
if (originalCipherTextLength <= BLOCK_SIZE + mac.macLength) {
throw InvalidMessageException("Message shorter than crypto overhead!")
}
val wrappedStream: InputStream = IncrementalMacInputStream(
IncrementalMacAdditionalValidationsInputStream(
wrapped = archiveStream,
fileLength = file.length(),
mac = mac,
theirDigest = digest
),
parts[1],
ChunkSizeChoice.everyNthByte(incrementalMacChunkSize),
incrementalDigest
)
val inputStream: InputStream = AttachmentCipherInputStream(
inputStream = wrappedStream,
aesKey = parts[0],
totalDataSize = file.length() - BLOCK_SIZE - mac.macLength
)
return if (plaintextLength != 0L) {
LimitedInputStream(inputStream, plaintextLength)
} else {
withoutLimits(inputStream)
}
}
@JvmStatic
@Throws(InvalidMessageException::class, IOException::class)
fun createForStickerData(data: ByteArray, packKey: ByteArray?): InputStream {
val combinedKeyMaterial = HKDF.deriveSecrets(packKey, "Sticker Pack".toByteArray(), 64)
val parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE)
val mac = initMac(parts[1])
if (data.size <= BLOCK_SIZE + mac.macLength) {
throw InvalidMessageException("Message shorter than crypto overhead!")
}
ByteArrayInputStream(data).use { inputStream ->
verifyMac(inputStream, data.size.toLong(), mac, null)
}
return AttachmentCipherInputStream(ByteArrayInputStream(data), parts[0], (data.size - BLOCK_SIZE - mac.macLength).toLong())
}
private fun initMac(key: ByteArray): Mac {
try {
val mac = Mac.getInstance("HmacSHA256")
mac.init(SecretKeySpec(key, "HmacSHA256"))
return mac
} catch (e: NoSuchAlgorithmException) {
throw AssertionError(e)
} catch (e: InvalidKeyException) {
throw AssertionError(e)
}
}
@Throws(InvalidMessageException::class)
private fun verifyMac(@Nonnull inputStream: InputStream, length: Long, @Nonnull mac: Mac, theirDigest: ByteArray?) {
try {
val digest = MessageDigest.getInstance("SHA256")
var remainingData = Util.toIntExact(length) - mac.macLength
val buffer = ByteArray(4096)
while (remainingData > 0) {
val read = inputStream.read(buffer, 0, min(buffer.size, remainingData))
mac.update(buffer, 0, read)
digest.update(buffer, 0, read)
remainingData -= read
}
val ourMac = mac.doFinal()
val theirMac = ByteArray(mac.macLength)
Util.readFully(inputStream, theirMac)
if (!MessageDigest.isEqual(ourMac, theirMac)) {
throw InvalidMessageException("MAC doesn't match!")
}
val ourDigest = digest.digest(theirMac)
if (theirDigest != null && !MessageDigest.isEqual(ourDigest, theirDigest)) {
throw InvalidMessageException("Digest doesn't match!")
}
} catch (e: IOException) {
throw InvalidMessageException(e)
} catch (e: ArithmeticException) {
throw InvalidMessageException(e)
} catch (e: NoSuchAlgorithmException) {
throw AssertionError(e)
}
}
}
} }

View File

@@ -0,0 +1,148 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.signalservice.api.crypto
import java.io.FilterInputStream
import java.io.IOException
import java.io.InputStream
import javax.annotation.Nonnull
import javax.crypto.BadPaddingException
import javax.crypto.Cipher
import javax.crypto.IllegalBlockSizeException
import javax.crypto.ShortBufferException
import kotlin.math.min
/**
* This is similar to [javax.crypto.CipherInputStream], but it fixes various issues, including proper error propagation,
* and proper handling of boundary conditions.
*/
class BetterCipherInputStream(
inputStream: InputStream,
val cipher: Cipher
) : FilterInputStream(inputStream) {
private var done = false
private var overflowBuffer: ByteArray? = null
@Throws(IOException::class)
override fun read(): Int {
val buffer = ByteArray(1)
var read: Int = read(buffer)
while (read == 0) {
read = read(buffer)
}
if (read == -1) {
return read
}
return buffer[0].toInt() and 0xFF
}
@Throws(IOException::class)
override fun read(@Nonnull buffer: ByteArray): Int {
return read(buffer, 0, buffer.size)
}
@Throws(IOException::class)
override fun read(@Nonnull buffer: ByteArray, offset: Int, length: Int): Int {
return if (!done) {
readIncremental(buffer, offset, length)
} else {
-1
}
}
override fun markSupported(): Boolean = false
@Throws(IOException::class)
override fun skip(byteCount: Long): Long {
val buffer = ByteArray(4096)
var skipped = 0L
while (skipped < byteCount) {
val remaining = byteCount - skipped
val read = read(buffer, 0, remaining.toInt())
skipped += read.toLong()
}
return skipped
}
@Throws(IOException::class)
private fun readIncremental(outputBuffer: ByteArray, originalOffset: Int, originalLength: Int): Int {
var offset = originalOffset
var length = originalLength
var readLength = 0
overflowBuffer?.let { overflow ->
if (overflow.size > length) {
overflow.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = length)
overflowBuffer = overflow.copyOfRange(fromIndex = length, toIndex = overflow.size)
return length
} else if (overflow.size == length) {
overflow.copyInto(destination = outputBuffer, destinationOffset = offset)
overflowBuffer = null
return length
} else {
overflow.copyInto(destination = outputBuffer, destinationOffset = offset)
readLength += overflow.size
offset += readLength
length -= readLength
overflowBuffer = null
}
}
val ciphertextBuffer = ByteArray(length)
val ciphertextRead = super.read(ciphertextBuffer, 0, ciphertextBuffer.size)
if (ciphertextRead == -1) {
return readFinal(outputBuffer, offset, length)
}
try {
var plaintextLength = cipher.getOutputSize(ciphertextRead)
if (plaintextLength <= length) {
readLength += cipher.update(ciphertextBuffer, 0, ciphertextRead, outputBuffer, offset)
return readLength
}
val plaintextBuffer = ByteArray(plaintextLength)
plaintextLength = cipher.update(ciphertextBuffer, 0, ciphertextRead, plaintextBuffer, 0)
if (plaintextLength <= length) {
plaintextBuffer.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = plaintextLength)
readLength += plaintextLength
} else {
plaintextBuffer.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = length)
overflowBuffer = plaintextBuffer.copyOfRange(fromIndex = length, toIndex = plaintextLength)
readLength += length
}
return readLength
} catch (e: ShortBufferException) {
throw AssertionError(e)
}
}
@Throws(IOException::class)
private fun readFinal(buffer: ByteArray, offset: Int, length: Int): Int {
try {
val internal = ByteArray(buffer.size)
val actualLength = min(length, cipher.doFinal(internal, 0))
internal.copyInto(destination = buffer, destinationOffset = offset, endIndex = actualLength)
done = true
return actualLength
} catch (e: IllegalBlockSizeException) {
throw IOException(e)
} catch (e: BadPaddingException) {
throw IOException(e)
} catch (e: ShortBufferException) {
throw IOException(e)
}
}
}

View File

@@ -161,32 +161,6 @@ class AttachmentCipherTest {
} }
} }
@Test(expected = InvalidMessageException::class)
fun attachment_decryptFailOnNullDigest_nonIncremental() {
attachment_decryptFailOnNullDigest(incremental = false)
}
@Test(expected = InvalidMessageException::class)
fun attachment_decryptFailOnNullDigest_incremental() {
attachment_decryptFailOnNullDigest(incremental = true)
}
private fun attachment_decryptFailOnNullDigest(incremental: Boolean) {
var cipherFile: File? = null
try {
val key = Util.getSecretBytes(64)
val plaintextInput = Util.getSecretBytes(MEBIBYTE)
val encryptResult = encryptData(plaintextInput, key, incremental)
cipherFile = writeToFile(encryptResult.ciphertext)
AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, null, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
} finally {
cipherFile?.delete()
}
}
@Test(expected = InvalidMessageException::class) @Test(expected = InvalidMessageException::class)
fun attachment_decryptFailOnBadDigest_nonIncremental() { fun attachment_decryptFailOnBadDigest_nonIncremental() {
attachment_decryptFailOnBadDigest(incremental = false) attachment_decryptFailOnBadDigest(incremental = false)
@@ -293,7 +267,7 @@ class AttachmentCipherTest {
val encryptResult = encryptData(plaintextInput, key, false) val encryptResult = encryptData(plaintextInput, key, false)
val cipherFile = writeToFile(encryptResult.ciphertext) val cipherFile = writeToFile(encryptResult.ciphertext)
val inputStream = AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, plaintextInput.size.toLong()) val inputStream = AttachmentCipherInputStream.createForArchivedMediaOuterLayer(keyMaterial, cipherFile, plaintextInput.size.toLong())
val plaintextOutput = readInputStreamFully(inputStream) val plaintextOutput = readInputStreamFully(inputStream)
assertThat(plaintextOutput).isEqualTo(plaintextInput) assertThat(plaintextOutput).isEqualTo(plaintextInput)
@@ -310,7 +284,7 @@ class AttachmentCipherTest {
val encryptResult = encryptData(plaintextInput, key, false) val encryptResult = encryptData(plaintextInput, key, false)
val cipherFile = writeToFile(encryptResult.ciphertext) val cipherFile = writeToFile(encryptResult.ciphertext)
val inputStream: InputStream = AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, plaintextInput.size.toLong()) val inputStream: InputStream = AttachmentCipherInputStream.createForArchivedMediaOuterLayer(keyMaterial, cipherFile, plaintextInput.size.toLong())
val plaintextOutput = readInputStreamFully(inputStream) val plaintextOutput = readInputStreamFully(inputStream)
assertThat(plaintextOutput).isEqualTo(plaintextInput) assertThat(plaintextOutput).isEqualTo(plaintextInput)
@@ -332,7 +306,7 @@ class AttachmentCipherTest {
val encryptResult = encryptData(plaintextInput, key, false) val encryptResult = encryptData(plaintextInput, key, false)
cipherFile = writeToFile(encryptResult.ciphertext) cipherFile = writeToFile(encryptResult.ciphertext)
AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, plaintextInput.size.toLong()) AttachmentCipherInputStream.createForArchivedMediaOuterLayer(keyMaterial, cipherFile, plaintextInput.size.toLong())
} catch (e: InvalidMessageException) { } catch (e: InvalidMessageException) {
hitCorrectException = true hitCorrectException = true
} finally { } finally {
@@ -368,7 +342,7 @@ class AttachmentCipherTest {
val cipherFile = writeToFile(encryptedData) val cipherFile = writeToFile(encryptedData)
val keyMaterial = createMediaKeyMaterial(key) val keyMaterial = createMediaKeyMaterial(key)
val decryptedStream: InputStream = AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, length.toLong()) val decryptedStream: InputStream = AttachmentCipherInputStream.createForArchivedMediaOuterLayer(keyMaterial, cipherFile, length.toLong())
val plaintextOutput = readInputStreamFully(decryptedStream) val plaintextOutput = readInputStreamFully(decryptedStream)
Assert.assertArrayEquals(plaintextInput, plaintextOutput) Assert.assertArrayEquals(plaintextInput, plaintextOutput)
@@ -377,6 +351,60 @@ class AttachmentCipherTest {
} }
} }
@Test
fun archiveInnerAndOuter_encryptDecrypt_nonIncremental() {
archiveInnerAndOuter_encryptDecrypt(incremental = false, fileSize = MEBIBYTE)
}
@Test
fun archiveInnerAndOuter_encryptDecrypt_incremental() {
archiveInnerAndOuter_encryptDecrypt(incremental = true, fileSize = MEBIBYTE)
}
@Test
fun archiveInnerAndOuter_encryptDecrypt_nonIncremental_manyFileSizes() {
for (i in 0..99) {
archiveInnerAndOuter_encryptDecrypt(incremental = false, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024))
}
}
@Test
fun archiveInnerAndOuter_encryptDecrypt_incremental_manyFileSizes() {
// Designed to stress the various boundary conditions of reading the final mac
for (i in 0..99) {
archiveInnerAndOuter_encryptDecrypt(incremental = true, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024))
}
}
private fun archiveInnerAndOuter_encryptDecrypt(incremental: Boolean, fileSize: Int) {
val innerKey = Util.getSecretBytes(64)
val plaintextInput = Util.getSecretBytes(fileSize)
val innerEncryptResult = encryptData(plaintextInput, innerKey, incremental)
val outerKey = Util.getSecretBytes(64)
val outerEncryptResult = encryptData(innerEncryptResult.ciphertext, outerKey, false)
val cipherFile = writeToFile(outerEncryptResult.ciphertext)
val keyMaterial = createMediaKeyMaterial(outerKey)
val decryptedStream: LimitedInputStream = AttachmentCipherInputStream.createForArchivedMediaOuterAndInnerLayers(
archivedMediaKeyMaterial = keyMaterial,
file = cipherFile,
originalCipherTextLength = innerEncryptResult.ciphertext.size.toLong(),
plaintextLength = plaintextInput.size.toLong(),
combinedKeyMaterial = innerKey,
digest = innerEncryptResult.digest,
incrementalDigest = innerEncryptResult.incrementalDigest,
incrementalMacChunkSize = innerEncryptResult.chunkSizeChoice
)
val plaintextOutput = decryptedStream.readFully(autoClose = false)
assertThat(plaintextOutput).isEqualTo(plaintextInput)
assertThat(decryptedStream.leftoverStream().allMatch { it == 0.toByte() }).isTrue()
cipherFile.delete()
}
@Test @Test
fun archive_decryptFailOnBadMac() { fun archive_decryptFailOnBadMac() {
var cipherFile: File? = null var cipherFile: File? = null
@@ -393,7 +421,7 @@ class AttachmentCipherTest {
cipherFile = writeToFile(badMacCiphertext) cipherFile = writeToFile(badMacCiphertext)
val keyMaterial = createMediaKeyMaterial(key) val keyMaterial = createMediaKeyMaterial(key)
AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, plaintextInput.size.toLong()) AttachmentCipherInputStream.createForArchivedMediaOuterLayer(keyMaterial, cipherFile, plaintextInput.size.toLong())
Assert.fail() Assert.fail()
} catch (e: InvalidMessageException) { } catch (e: InvalidMessageException) {
hitCorrectException = true hitCorrectException = true