Refactor AttachmentCipherInputStream.

This commit is contained in:
Greyson Parrelli
2025-06-18 10:35:45 -04:00
committed by Michelle Tang
parent 9798f5cc7c
commit 4f6a5de227
10 changed files with 572 additions and 403 deletions

View File

@@ -199,7 +199,7 @@ class AttachmentTableTest {
val attachmentId = SignalDatabase.attachments.insertAttachmentsForMessage(mmsId, listOf(createAttachmentPointer(key, badlyPaddedDigest, plaintext.size)), emptyList()).values.first()
// Give data to attachment table
val cipherInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintext.size.toLong(), key, badlyPaddedDigest, null, 4, false)
val cipherInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintext.size.toLong(), key, badlyPaddedDigest, null, 4)
SignalDatabase.attachments.finalizeAttachmentAfterDownload(mmsId, attachmentId, cipherInputStream, iv)
// Verify the digest has been updated to the properly padded one
@@ -230,7 +230,7 @@ class AttachmentTableTest {
val attachmentId = SignalDatabase.attachments.insertAttachmentsForMessage(mmsId, listOf(createAttachmentPointer(key, digest, plaintext.size)), emptyList()).values.first()
// Give data to attachment table
val cipherInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintext.size.toLong(), key, digest, null, 4, false)
val cipherInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintext.size.toLong(), key, digest, null, 4)
SignalDatabase.attachments.finalizeAttachmentAfterDownload(mmsId, attachmentId, cipherInputStream, iv)
// Verify the digest hasn't changed

View File

@@ -252,7 +252,6 @@ class RestoreAttachmentJob private constructor(
pointer,
attachmentFile,
maxReceiveSize,
false,
progressListener
)
} else {

View File

@@ -132,14 +132,13 @@ class RestoreAttachmentThumbnailJob private constructor(
Log.i(TAG, "Downloading thumbnail for $attachmentId")
val downloadResult = AppDependencies.signalServiceMessageReceiver
.retrieveArchivedAttachment(
.retrieveArchivedThumbnail(
SignalStore.backup.mediaRootBackupKey.deriveMediaSecrets(attachment.requireThumbnailMediaName()),
cdnCredentials,
thumbnailTransferFile,
pointer,
thumbnailFile,
maxThumbnailSize,
true,
progressListener
)

View File

@@ -149,7 +149,15 @@ class RestoreLocalAttachmentJob private constructor(
try {
val iv = ByteArray(16)
streamSupplier.openStream().use { StreamUtil.readFully(it, iv) }
AttachmentCipherInputStream.createForAttachment(streamSupplier, size, attachment.size, combinedKey, attachment.remoteDigest, null, 0, false).use { input ->
AttachmentCipherInputStream.createForAttachment(
streamSupplier = streamSupplier,
streamLength = size,
plaintextLength = attachment.size,
combinedKeyMaterial = combinedKey,
digest = attachment.remoteDigest,
incrementalDigest = null,
incrementalMacChunkSize = 0
).use { input ->
SignalDatabase.attachments.finalizeAttachmentAfterDownload(attachment.mmsId, attachment.attachmentId, input, iv)
}
} catch (e: InvalidMessageException) {

View File

@@ -11,6 +11,7 @@ import org.signal.libsignal.protocol.InvalidMessageException;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Optional;
@@ -37,7 +38,12 @@ class AttachmentStreamLocalUriFetcher implements DataFetcher<InputStream> {
public void loadData(@NonNull Priority priority, @NonNull DataCallback<? super InputStream> callback) {
try {
if (!digest.isPresent()) throw new InvalidMessageException("No attachment digest!");
is = AttachmentCipherInputStream.createForAttachment(attachment, plaintextLength, key, digest.get(), null, 0);
is = AttachmentCipherInputStream.createForAttachment(attachment,
plaintextLength,
key,
digest.get(),
null,
0);
callback.onDataReady(is);
} catch (IOException | InvalidMessageException e) {
callback.onLoadFailed(e);

View File

@@ -71,7 +71,8 @@ class PartDataSource implements DataSource {
final boolean hasData = attachment.hasData;
if (inProgress && !hasData && hasIncrementalDigest && attachmentKey != null) {
final byte[] decode = Base64.decode(attachmentKey);
final byte[] decodedKey = Base64.decode(attachmentKey);
if (attachment.transferState == AttachmentTable.TRANSFER_RESTORE_IN_PROGRESS && attachment.archiveTransferState == AttachmentTable.ArchiveTransferState.FINISHED) {
final File archiveFile = attachmentDatabase.getOrCreateArchiveTransferFile(attachment.attachmentId);
try {
@@ -81,7 +82,11 @@ class PartDataSource implements DataSource {
MediaRootBackupKey.MediaKeyMaterial mediaKeyMaterial = SignalStore.backup().getMediaRootBackupKey().deriveMediaSecretsFromMediaId(mediaId);
long originalCipherLength = AttachmentCipherStreamUtil.getCiphertextLength(PaddingInputStream.getPaddedSize(attachment.size));
this.inputStream = AttachmentCipherInputStream.createStreamingForArchivedAttachment(mediaKeyMaterial, archiveFile, originalCipherLength, attachment.size, attachment.remoteDigest, decode, attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize);
if (attachment.remoteDigest == null) {
throw new InvalidMessageException("Missing digest!");
}
this.inputStream = AttachmentCipherInputStream.createForArchivedMediaOuterAndInnerLayers(mediaKeyMaterial, archiveFile, originalCipherLength, attachment.size, decodedKey, attachment.remoteDigest, attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize);
} catch (InvalidMessageException e) {
throw new IOException("Error decrypting attachment stream!", e);
}
@@ -95,7 +100,7 @@ class PartDataSource implements DataSource {
throw new InvalidMessageException("Missing digest!");
}
this.inputStream = AttachmentCipherInputStream.createForAttachment(streamSupplier, streamLength, attachment.size, decode, attachment.remoteDigest, attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize, false);
this.inputStream = AttachmentCipherInputStream.createForAttachment(streamSupplier, streamLength, attachment.size, decodedKey, attachment.remoteDigest, attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize);
} catch (InvalidMessageException e) {
throw new IOException("Error decrypting attachment stream!", e);
}

View File

@@ -7,9 +7,6 @@
package org.whispersystems.signalservice.api;
import org.signal.core.util.StreamUtil;
import org.signal.core.util.concurrent.FutureTransformers;
import org.signal.core.util.concurrent.ListenableFuture;
import org.signal.core.util.concurrent.SettableFuture;
import org.signal.core.util.stream.LimitedInputStream;
import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.zkgroup.profiles.ProfileKey;
@@ -18,24 +15,15 @@ import org.whispersystems.signalservice.api.backup.MediaRootBackupKey;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherStreamUtil;
import org.whispersystems.signalservice.api.crypto.ProfileCipherInputStream;
import org.whispersystems.signalservice.api.crypto.SealedSenderAccess;
import org.whispersystems.signalservice.api.messages.SignalServiceAttachment.ProgressListener;
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer;
import org.whispersystems.signalservice.api.messages.SignalServiceDataMessage;
import org.whispersystems.signalservice.api.messages.SignalServiceStickerManifest;
import org.whispersystems.signalservice.api.profiles.ProfileAndCredential;
import org.whispersystems.signalservice.api.profiles.SignalServiceProfile;
import org.whispersystems.signalservice.api.push.ServiceId.ACI;
import org.whispersystems.signalservice.api.push.SignalServiceAddress;
import org.whispersystems.signalservice.api.push.exceptions.MissingConfigurationException;
import org.whispersystems.signalservice.internal.ServiceResponse;
import org.whispersystems.signalservice.internal.crypto.PaddingInputStream;
import org.whispersystems.signalservice.internal.push.IdentityCheckRequest;
import org.whispersystems.signalservice.internal.push.IdentityCheckResponse;
import org.whispersystems.signalservice.internal.push.PushServiceSocket;
import org.whispersystems.signalservice.internal.sticker.Pack;
import org.whispersystems.signalservice.internal.util.Util;
import org.whispersystems.signalservice.internal.websocket.ResponseMapper;
import java.io.File;
import java.io.FileInputStream;
@@ -46,15 +34,11 @@ import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import io.reactivex.rxjava3.core.Single;
/**
* The primary interface for receiving Signal Service messages.
*
@@ -118,6 +102,7 @@ public class SignalServiceMessageReceiver {
public AttachmentDownloadResult retrieveAttachment(SignalServiceAttachmentPointer pointer, File destination, long maxSizeBytes, ProgressListener listener)
throws IOException, InvalidMessageException, MissingConfigurationException {
if (!pointer.getDigest().isPresent()) throw new InvalidMessageException("No attachment digest!");
if (pointer.getKey() == null) throw new InvalidMessageException("No key!");
socket.retrieveAttachment(pointer.getCdnNumber(), Collections.emptyMap(), pointer.getRemoteId(), destination, maxSizeBytes, listener);
@@ -127,7 +112,14 @@ public class SignalServiceMessageReceiver {
}
return new AttachmentDownloadResult(
AttachmentCipherInputStream.createForAttachment(destination, pointer.getSize().orElse(0), pointer.getKey(), pointer.getDigest().get(), null, 0),
AttachmentCipherInputStream.createForAttachment(
destination,
pointer.getSize().orElse(0),
pointer.getKey(),
pointer.getDigest().get(),
null,
0
),
iv
);
}
@@ -150,14 +142,17 @@ public class SignalServiceMessageReceiver {
@Nonnull SignalServiceAttachmentPointer pointer,
@Nonnull File attachmentDestination,
long maxSizeBytes,
boolean ignoreDigest,
@Nullable ProgressListener listener)
throws IOException, InvalidMessageException, MissingConfigurationException
{
if (!ignoreDigest && pointer.getDigest().isEmpty()) {
if (pointer.getDigest().isEmpty()) {
throw new InvalidMessageException("No attachment digest!");
}
if (pointer.getKey() == null) {
throw new InvalidMessageException("No key!");
}
socket.retrieveAttachment(pointer.getCdnNumber(), readCredentialHeaders, pointer.getRemoteId(), archiveDestination, maxSizeBytes, listener);
long originalCipherLength = pointer.getSize()
@@ -166,7 +161,7 @@ public class SignalServiceMessageReceiver {
.orElse(0L);
// There's two layers of encryption -- one from the backup, and one from the attachment. This only strips the outermost backup encryption layer.
try (InputStream backupDecrypted = AttachmentCipherInputStream.createForArchivedMedia(archivedMediaKeyMaterial, archiveDestination, originalCipherLength)) {
try (InputStream backupDecrypted = AttachmentCipherInputStream.createForArchivedMediaOuterLayer(archivedMediaKeyMaterial, archiveDestination, originalCipherLength)) {
try (FileOutputStream fos = new FileOutputStream(attachmentDestination)) {
// TODO [backup] I don't think we should be doing the full copy here. This is basically doing the entire download inline in this single line.
StreamUtil.copy(backupDecrypted, fos);
@@ -182,10 +177,63 @@ public class SignalServiceMessageReceiver {
attachmentDestination,
pointer.getSize().orElse(0),
pointer.getKey(),
ignoreDigest ? null : pointer.getDigest().get(),
pointer.getDigest().get(),
null,
0,
ignoreDigest
0
);
return new AttachmentDownloadResult(dataStream, iv);
}
/**
* Retrieves an archived media attachment.
*
* @param archivedMediaKeyMaterial Decryption key material for decrypting outer layer of archived media.
* @param readCredentialHeaders Headers to pass to the backup CDN to authorize the download
* @param archiveDestination The download destination for archived attachment. If this file exists, download will resume.
* @param pointer The {@link SignalServiceAttachmentPointer} received in a {@link SignalServiceDataMessage}.
* @param attachmentDestination The download destination for this attachment. If this file exists, it is assumed that this is previously-downloaded content that can be resumed.
* @param listener An optional listener (may be null) to receive callbacks on download progress.
*
* @return An InputStream that streams the plaintext attachment contents.
*/
public AttachmentDownloadResult retrieveArchivedThumbnail(@Nonnull MediaRootBackupKey.MediaKeyMaterial archivedMediaKeyMaterial,
@Nonnull Map<String, String> readCredentialHeaders,
@Nonnull File archiveDestination,
@Nonnull SignalServiceAttachmentPointer pointer,
@Nonnull File attachmentDestination,
long maxSizeBytes,
@Nullable ProgressListener listener)
throws IOException, InvalidMessageException, MissingConfigurationException
{
if (pointer.getKey() == null) {
throw new InvalidMessageException("No key!");
}
socket.retrieveAttachment(pointer.getCdnNumber(), readCredentialHeaders, pointer.getRemoteId(), archiveDestination, maxSizeBytes, listener);
long originalCipherLength = pointer.getSize()
.filter(s -> s > 0)
.map(s -> AttachmentCipherStreamUtil.getCiphertextLength(PaddingInputStream.getPaddedSize(s)))
.orElse(0L);
// There's two layers of encryption -- one from the backup, and one from the attachment. This only strips the outermost backup encryption layer.
try (InputStream backupDecrypted = AttachmentCipherInputStream.createForArchivedMediaOuterLayer(archivedMediaKeyMaterial, archiveDestination, originalCipherLength)) {
try (FileOutputStream fos = new FileOutputStream(attachmentDestination)) {
// TODO [backup] I don't think we should be doing the full copy here. This is basically doing the entire download inline in this single line.
StreamUtil.copy(backupDecrypted, fos);
}
}
byte[] iv = new byte[16];
try (InputStream tempStream = new FileInputStream(attachmentDestination)) {
StreamUtil.readFully(tempStream, iv);
}
LimitedInputStream dataStream = AttachmentCipherInputStream.createForArchiveThumbnailInnerLayer(
attachmentDestination,
pointer.getSize().orElse(0),
pointer.getKey()
);
return new AttachmentDownloadResult(dataStream, iv);

View File

@@ -5,8 +5,8 @@
*/
package org.whispersystems.signalservice.api.crypto
import org.signal.core.util.readNBytesOrThrow
import org.signal.core.util.stream.LimitedInputStream
import org.signal.core.util.stream.LimitedInputStream.Companion.withoutLimits
import org.signal.libsignal.protocol.InvalidMessageException
import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice
import org.signal.libsignal.protocol.incrementalmac.IncrementalMacInputStream
@@ -16,216 +16,62 @@ import org.whispersystems.signalservice.internal.util.Util
import java.io.ByteArrayInputStream
import java.io.File
import java.io.FileInputStream
import java.io.FilterInputStream
import java.io.IOException
import java.io.InputStream
import java.security.InvalidKeyException
import java.security.MessageDigest
import java.security.NoSuchAlgorithmException
import javax.annotation.Nonnull
import javax.crypto.BadPaddingException
import javax.crypto.Cipher
import javax.crypto.IllegalBlockSizeException
import javax.crypto.Mac
import javax.crypto.ShortBufferException
import javax.crypto.spec.IvParameterSpec
import javax.crypto.spec.SecretKeySpec
import kotlin.math.min
/**
* Class for streaming an encrypted push attachment off disk.
* Decrypts an attachment stream that has been encrypted with AES/CBC/PKCS5Padding.
*
* @author Moxie Marlinspike
* It assumes that the first 16 bytes of the stream are the IV, and that the rest of the stream is encrypted data.
*/
class AttachmentCipherInputStream private constructor(
inputStream: InputStream,
aesKey: ByteArray,
private val totalDataSize: Long
) : FilterInputStream(inputStream) {
object AttachmentCipherInputStream {
private val cipher: Cipher
private var done = false
private var totalRead: Long = 0
private var overflowBuffer: ByteArray? = null
init {
val iv = ByteArray(BLOCK_SIZE)
readFullyWithoutDecrypting(iv)
this.cipher = Cipher.getInstance("AES/CBC/PKCS5Padding")
cipher.init(Cipher.DECRYPT_MODE, SecretKeySpec(aesKey, "AES"), IvParameterSpec(iv))
}
@Throws(IOException::class)
override fun read(): Int {
val buffer = ByteArray(1)
var read: Int = read(buffer)
while (read == 0) {
read = read(buffer)
}
if (read == -1) {
return read
}
return buffer[0].toInt() and 0xFF
}
@Throws(IOException::class)
override fun read(@Nonnull buffer: ByteArray): Int {
return read(buffer, 0, buffer.size)
}
@Throws(IOException::class)
override fun read(@Nonnull buffer: ByteArray, offset: Int, length: Int): Int {
return if (totalRead != totalDataSize) {
readIncremental(buffer, offset, length)
} else if (!done) {
readFinal(buffer, offset, length)
} else {
-1
}
}
override fun markSupported(): Boolean = false
@Throws(IOException::class)
override fun skip(byteCount: Long): Long {
var skipped = 0L
while (skipped < byteCount) {
val remaining = byteCount - skipped
val buffer = ByteArray(min(4096, remaining.toInt()))
val read = read(buffer)
skipped += read.toLong()
}
return skipped
}
@Throws(IOException::class)
private fun readIncremental(outputBuffer: ByteArray, originalOffset: Int, originalLength: Int): Int {
var offset = originalOffset
var length = originalLength
var readLength = 0
overflowBuffer?.let { overflow ->
if (overflow.size > length) {
overflow.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = length)
overflowBuffer = overflow.copyOfRange(fromIndex = length, toIndex = overflow.size)
return length
} else if (overflow.size == length) {
overflow.copyInto(destination = outputBuffer, destinationOffset = offset)
overflowBuffer = null
return length
} else {
overflow.copyInto(destination = outputBuffer, destinationOffset = offset)
readLength += overflow.size
offset += readLength
length -= readLength
overflowBuffer = null
}
}
if (length + totalRead > totalDataSize) {
length = (totalDataSize - totalRead).toInt()
}
val ciphertextBuffer = ByteArray(length)
val ciphertextReadLength = if (ciphertextBuffer.size <= cipher.blockSize) {
ciphertextBuffer.size
} else {
// Ensure we leave the final block for readFinal()
ciphertextBuffer.size - cipher.blockSize
}
val ciphertextRead = super.read(ciphertextBuffer, 0, ciphertextReadLength)
totalRead += ciphertextRead.toLong()
try {
var plaintextLength = cipher.getOutputSize(ciphertextRead)
if (plaintextLength <= length) {
readLength += cipher.update(ciphertextBuffer, 0, ciphertextRead, outputBuffer, offset)
return readLength
}
val plaintextBuffer = ByteArray(plaintextLength)
plaintextLength = cipher.update(ciphertextBuffer, 0, ciphertextRead, plaintextBuffer, 0)
if (plaintextLength <= length) {
plaintextBuffer.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = plaintextLength)
readLength += plaintextLength
} else {
plaintextBuffer.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = length)
overflowBuffer = plaintextBuffer.copyOfRange(fromIndex = length, toIndex = plaintextLength)
readLength += length
}
return readLength
} catch (e: ShortBufferException) {
throw AssertionError(e)
}
}
@Throws(IOException::class)
private fun readFinal(buffer: ByteArray, offset: Int, length: Int): Int {
try {
val internal = ByteArray(buffer.size)
val actualLength = min(length, cipher.doFinal(internal, 0))
internal.copyInto(destination = buffer, destinationOffset = offset, endIndex = actualLength)
done = true
return actualLength
} catch (e: IllegalBlockSizeException) {
throw IOException(e)
} catch (e: BadPaddingException) {
throw IOException(e)
} catch (e: ShortBufferException) {
throw IOException(e)
}
}
@Throws(IOException::class)
private fun readFullyWithoutDecrypting(buffer: ByteArray) {
var offset = 0
while (true) {
val read = super.read(buffer, offset, buffer.size - offset)
if (read + offset < buffer.size) {
offset += read
} else {
return
}
}
}
fun interface StreamSupplier {
@Nonnull
@Throws(IOException::class)
fun openStream(): InputStream
}
companion object {
private const val BLOCK_SIZE = 16
private const val CIPHER_KEY_SIZE = 32
private const val MAC_KEY_SIZE = 32
/**
* Passing in a null incrementalDigest and/or 0 for the chunk size at the call site disables incremental mac validation.
* Creates a stream to decrypt a typical attachment via a [File].
*
* Passing in true for ignoreDigest DOES NOT VERIFY THE DIGEST
* @param incrementalDigest If null, incremental mac validation is disabled.
* @param incrementalMacChunkSize If 0, incremental mac validation is disabled.
*/
@JvmStatic
@JvmOverloads
@Throws(InvalidMessageException::class, IOException::class)
fun createForAttachment(file: File, plaintextLength: Long, combinedKeyMaterial: ByteArray?, digest: ByteArray?, incrementalDigest: ByteArray?, incrementalMacChunkSize: Int, ignoreDigest: Boolean = false): LimitedInputStream {
return createForAttachment({ FileInputStream(file) }, file.length(), plaintextLength, combinedKeyMaterial, digest, incrementalDigest, incrementalMacChunkSize, ignoreDigest)
fun createForAttachment(
file: File,
plaintextLength: Long,
combinedKeyMaterial: ByteArray,
digest: ByteArray,
incrementalDigest: ByteArray?,
incrementalMacChunkSize: Int
): LimitedInputStream {
return create(
streamSupplier = { FileInputStream(file) },
streamLength = file.length(),
plaintextLength = plaintextLength,
combinedKeyMaterial = combinedKeyMaterial,
digest = digest,
incrementalDigest = incrementalDigest,
incrementalMacChunkSize = incrementalMacChunkSize,
ignoreDigest = false
)
}
/**
* Passing in a null incrementalDigest and/or 0 for the chunk size at the call site disables incremental mac validation.
* Creates a stream to decrypt a typical attachment via a [StreamSupplier].
*
* Passing in true for ignoreDigest DOES NOT VERIFY THE DIGEST
* @param incrementalDigest If null, incremental mac validation is disabled.
* @param incrementalMacChunkSize If 0, incremental mac validation is disabled.
*/
@JvmStatic
@Throws(InvalidMessageException::class, IOException::class)
@@ -233,14 +79,150 @@ class AttachmentCipherInputStream private constructor(
streamSupplier: StreamSupplier,
streamLength: Long,
plaintextLength: Long,
combinedKeyMaterial: ByteArray?,
combinedKeyMaterial: ByteArray,
digest: ByteArray,
incrementalDigest: ByteArray?,
incrementalMacChunkSize: Int
): LimitedInputStream {
return create(
streamSupplier = streamSupplier,
streamLength = streamLength,
plaintextLength = plaintextLength,
combinedKeyMaterial = combinedKeyMaterial,
digest = digest,
incrementalDigest = incrementalDigest,
incrementalMacChunkSize = incrementalMacChunkSize,
ignoreDigest = false
)
}
/**
* After removing the server layer of encryption using [createForArchivedMediaOuterLayer], use this to decrypt the inner layer of the attachment.
* Thumbnails have a special path because we don't do any additional digest/hash validation on them.
*/
@JvmStatic
@Throws(InvalidMessageException::class, IOException::class)
fun createForArchiveThumbnailInnerLayer(
file: File,
plaintextLength: Long,
combinedKeyMaterial: ByteArray
): LimitedInputStream {
return create(
streamSupplier = { FileInputStream(file) },
streamLength = file.length(),
plaintextLength = plaintextLength,
combinedKeyMaterial = combinedKeyMaterial,
digest = null,
incrementalDigest = null,
incrementalMacChunkSize = 0,
ignoreDigest = true
)
}
/**
* When you archive an attachment, you give the server an encrypted attachment, and the server wraps it in *another* layer of encryption.
* This will return a stream that unwraps the server's layer of encryption, giving you a stream that contains a "normally-encrypted" attachment.
*
* Because we're validating the encryptedDigest/plaintextHash of the inner layer, there's no additional out-of-band validation of this outer layer.
*/
@JvmStatic
@Throws(InvalidMessageException::class, IOException::class)
fun createForArchivedMediaOuterLayer(archivedMediaKeyMaterial: MediaKeyMaterial, file: File, originalCipherTextLength: Long): LimitedInputStream {
val mac = initMac(archivedMediaKeyMaterial.macKey)
if (file.length() <= BLOCK_SIZE + mac.macLength) {
throw InvalidMessageException("Message shorter than crypto overhead!")
}
FileInputStream(file).use { macVerificationStream ->
verifyMac(macVerificationStream, file.length(), mac, null)
}
val encryptedStream = FileInputStream(file)
val encryptedStreamExcludingMac = LimitedInputStream(encryptedStream, file.length() - mac.macLength)
val cipher = createCipher(encryptedStreamExcludingMac, archivedMediaKeyMaterial.aesKey)
val inputStream: InputStream = BetterCipherInputStream(encryptedStreamExcludingMac, cipher)
return LimitedInputStream(inputStream, originalCipherTextLength)
}
/**
* When you archive an attachment, you give the server an encrypted attachment, and the server wraps it in *another* layer of encryption.
*
* This creates a stream decrypt both the inner and outer layers of an archived attachment at the same time by basically double-decrypting it.
*
* @param incrementalDigest If null, incremental mac validation is disabled.
* @param incrementalMacChunkSize If 0, incremental mac validation is disabled.
*/
@JvmStatic
@Throws(InvalidMessageException::class, IOException::class)
fun createForArchivedMediaOuterAndInnerLayers(
archivedMediaKeyMaterial: MediaKeyMaterial,
file: File,
originalCipherTextLength: Long,
plaintextLength: Long,
combinedKeyMaterial: ByteArray,
digest: ByteArray,
incrementalDigest: ByteArray?,
incrementalMacChunkSize: Int
): LimitedInputStream {
val keyMaterial = CombinedKeyMaterial.from(combinedKeyMaterial)
val mac = initMac(keyMaterial.macKey)
if (originalCipherTextLength <= BLOCK_SIZE + mac.macLength) {
throw InvalidMessageException("Message shorter than crypto overhead!")
}
return create(
streamSupplier = { createForArchivedMediaOuterLayer(archivedMediaKeyMaterial, file, originalCipherTextLength) },
streamLength = originalCipherTextLength,
plaintextLength = plaintextLength,
combinedKeyMaterial = combinedKeyMaterial,
digest = digest,
incrementalDigest = incrementalDigest,
incrementalMacChunkSize = incrementalMacChunkSize,
ignoreDigest = false
)
}
/**
* Creates a stream to decrypt sticker data. Stickers have a special path because the key material is derived from the pack key.
*/
@JvmStatic
@Throws(InvalidMessageException::class, IOException::class)
fun createForStickerData(data: ByteArray, packKey: ByteArray): InputStream {
val keyMaterial = CombinedKeyMaterial.from(HKDF.deriveSecrets(packKey, "Sticker Pack".toByteArray(), 64))
val mac = initMac(keyMaterial.macKey)
if (data.size <= BLOCK_SIZE + mac.macLength) {
throw InvalidMessageException("Message shorter than crypto overhead!")
}
ByteArrayInputStream(data).use { inputStream ->
verifyMac(inputStream, data.size.toLong(), mac, null)
}
val encryptedStream = ByteArrayInputStream(data)
val encryptedStreamExcludingMac = LimitedInputStream(encryptedStream, data.size.toLong() - mac.macLength)
val cipher = createCipher(encryptedStreamExcludingMac, keyMaterial.aesKey)
return BetterCipherInputStream(encryptedStreamExcludingMac, cipher)
}
@JvmStatic
@Throws(InvalidMessageException::class, IOException::class)
private fun create(
streamSupplier: StreamSupplier,
streamLength: Long,
plaintextLength: Long,
combinedKeyMaterial: ByteArray,
digest: ByteArray?,
incrementalDigest: ByteArray?,
incrementalMacChunkSize: Int,
ignoreDigest: Boolean
): LimitedInputStream {
val parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE)
val mac = initMac(parts[1])
val keyMaterial = CombinedKeyMaterial.from(combinedKeyMaterial)
val mac = initMac(keyMaterial.macKey)
if (streamLength <= BLOCK_SIZE + mac.macLength) {
throw InvalidMessageException("Message shorter than crypto overhead! length: $streamLength")
@@ -259,104 +241,36 @@ class AttachmentCipherInputStream private constructor(
}
wrappedStream = streamSupplier.openStream()
} else {
if (digest == null) {
throw InvalidMessageException("Missing digest for incremental mac validation!")
}
wrappedStream = IncrementalMacInputStream(
IncrementalMacAdditionalValidationsInputStream(
streamSupplier.openStream(),
streamLength,
mac,
digest!!
),
parts[1],
ChunkSizeChoice.everyNthByte(incrementalMacChunkSize),
incrementalDigest
)
}
val inputStream: InputStream = AttachmentCipherInputStream(wrappedStream, parts[0], streamLength - BLOCK_SIZE - mac.macLength)
return LimitedInputStream(inputStream, plaintextLength)
}
/**
* Decrypt archived media to it's original attachment encrypted blob.
*/
@JvmStatic
@Throws(InvalidMessageException::class, IOException::class)
fun createForArchivedMedia(archivedMediaKeyMaterial: MediaKeyMaterial, file: File, originalCipherTextLength: Long): LimitedInputStream {
val mac = initMac(archivedMediaKeyMaterial.macKey)
if (file.length() <= BLOCK_SIZE + mac.macLength) {
throw InvalidMessageException("Message shorter than crypto overhead!")
}
FileInputStream(file).use { macVerificationStream ->
verifyMac(macVerificationStream, file.length(), mac, null)
}
val inputStream: InputStream = AttachmentCipherInputStream(FileInputStream(file), archivedMediaKeyMaterial.aesKey, file.length() - BLOCK_SIZE - mac.macLength)
return LimitedInputStream(inputStream, originalCipherTextLength)
}
@JvmStatic
@Throws(InvalidMessageException::class, IOException::class)
fun createStreamingForArchivedAttachment(
archivedMediaKeyMaterial: MediaKeyMaterial,
file: File,
originalCipherTextLength: Long,
plaintextLength: Long,
combinedKeyMaterial: ByteArray?,
digest: ByteArray,
incrementalDigest: ByteArray?,
incrementalMacChunkSize: Int
): LimitedInputStream {
val archiveStream: InputStream = createForArchivedMedia(archivedMediaKeyMaterial, file, originalCipherTextLength)
val parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE)
val mac = initMac(parts[1])
if (originalCipherTextLength <= BLOCK_SIZE + mac.macLength) {
throw InvalidMessageException("Message shorter than crypto overhead!")
}
val wrappedStream: InputStream = IncrementalMacInputStream(
IncrementalMacAdditionalValidationsInputStream(
wrapped = archiveStream,
fileLength = file.length(),
wrapped = streamSupplier.openStream(),
fileLength = streamLength,
mac = mac,
theirDigest = digest
),
parts[1],
keyMaterial.macKey,
ChunkSizeChoice.everyNthByte(incrementalMacChunkSize),
incrementalDigest
)
val inputStream: InputStream = AttachmentCipherInputStream(
inputStream = wrappedStream,
aesKey = parts[0],
totalDataSize = file.length() - BLOCK_SIZE - mac.macLength
)
return if (plaintextLength != 0L) {
LimitedInputStream(inputStream, plaintextLength)
} else {
withoutLimits(inputStream)
}
}
@JvmStatic
@Throws(InvalidMessageException::class, IOException::class)
fun createForStickerData(data: ByteArray, packKey: ByteArray?): InputStream {
val combinedKeyMaterial = HKDF.deriveSecrets(packKey, "Sticker Pack".toByteArray(), 64)
val parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE)
val mac = initMac(parts[1])
val encryptedStreamExcludingMac = LimitedInputStream(wrappedStream, streamLength - mac.macLength)
val cipher = createCipher(encryptedStreamExcludingMac, keyMaterial.aesKey)
val decryptingStream: InputStream = BetterCipherInputStream(encryptedStreamExcludingMac, cipher)
if (data.size <= BLOCK_SIZE + mac.macLength) {
throw InvalidMessageException("Message shorter than crypto overhead!")
return LimitedInputStream(decryptingStream, plaintextLength)
}
ByteArrayInputStream(data).use { inputStream ->
verifyMac(inputStream, data.size.toLong(), mac, null)
private fun createCipher(inputStream: InputStream, aesKey: ByteArray): Cipher {
val iv = inputStream.readNBytesOrThrow(BLOCK_SIZE)
return Cipher.getInstance("AES/CBC/PKCS5Padding").apply {
init(Cipher.DECRYPT_MODE, SecretKeySpec(aesKey, "AES"), IvParameterSpec(iv))
}
return AttachmentCipherInputStream(ByteArrayInputStream(data), parts[0], (data.size - BLOCK_SIZE - mac.macLength).toLong())
}
private fun initMac(key: ByteArray): Mac {
@@ -406,5 +320,19 @@ class AttachmentCipherInputStream private constructor(
throw AssertionError(e)
}
}
private class CombinedKeyMaterial(val aesKey: ByteArray, val macKey: ByteArray) {
companion object {
fun from(combinedKeyMaterial: ByteArray): CombinedKeyMaterial {
val parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE)
return CombinedKeyMaterial(parts[0], parts[1])
}
}
}
fun interface StreamSupplier {
@Nonnull
@Throws(IOException::class)
fun openStream(): InputStream
}
}

View File

@@ -0,0 +1,148 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.signalservice.api.crypto
import java.io.FilterInputStream
import java.io.IOException
import java.io.InputStream
import javax.annotation.Nonnull
import javax.crypto.BadPaddingException
import javax.crypto.Cipher
import javax.crypto.IllegalBlockSizeException
import javax.crypto.ShortBufferException
import kotlin.math.min
/**
* This is similar to [javax.crypto.CipherInputStream], but it fixes various issues, including proper error propagation,
* and proper handling of boundary conditions.
*/
class BetterCipherInputStream(
inputStream: InputStream,
val cipher: Cipher
) : FilterInputStream(inputStream) {
private var done = false
private var overflowBuffer: ByteArray? = null
@Throws(IOException::class)
override fun read(): Int {
val buffer = ByteArray(1)
var read: Int = read(buffer)
while (read == 0) {
read = read(buffer)
}
if (read == -1) {
return read
}
return buffer[0].toInt() and 0xFF
}
@Throws(IOException::class)
override fun read(@Nonnull buffer: ByteArray): Int {
return read(buffer, 0, buffer.size)
}
@Throws(IOException::class)
override fun read(@Nonnull buffer: ByteArray, offset: Int, length: Int): Int {
return if (!done) {
readIncremental(buffer, offset, length)
} else {
-1
}
}
override fun markSupported(): Boolean = false
@Throws(IOException::class)
override fun skip(byteCount: Long): Long {
val buffer = ByteArray(4096)
var skipped = 0L
while (skipped < byteCount) {
val remaining = byteCount - skipped
val read = read(buffer, 0, remaining.toInt())
skipped += read.toLong()
}
return skipped
}
@Throws(IOException::class)
private fun readIncremental(outputBuffer: ByteArray, originalOffset: Int, originalLength: Int): Int {
var offset = originalOffset
var length = originalLength
var readLength = 0
overflowBuffer?.let { overflow ->
if (overflow.size > length) {
overflow.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = length)
overflowBuffer = overflow.copyOfRange(fromIndex = length, toIndex = overflow.size)
return length
} else if (overflow.size == length) {
overflow.copyInto(destination = outputBuffer, destinationOffset = offset)
overflowBuffer = null
return length
} else {
overflow.copyInto(destination = outputBuffer, destinationOffset = offset)
readLength += overflow.size
offset += readLength
length -= readLength
overflowBuffer = null
}
}
val ciphertextBuffer = ByteArray(length)
val ciphertextRead = super.read(ciphertextBuffer, 0, ciphertextBuffer.size)
if (ciphertextRead == -1) {
return readFinal(outputBuffer, offset, length)
}
try {
var plaintextLength = cipher.getOutputSize(ciphertextRead)
if (plaintextLength <= length) {
readLength += cipher.update(ciphertextBuffer, 0, ciphertextRead, outputBuffer, offset)
return readLength
}
val plaintextBuffer = ByteArray(plaintextLength)
plaintextLength = cipher.update(ciphertextBuffer, 0, ciphertextRead, plaintextBuffer, 0)
if (plaintextLength <= length) {
plaintextBuffer.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = plaintextLength)
readLength += plaintextLength
} else {
plaintextBuffer.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = length)
overflowBuffer = plaintextBuffer.copyOfRange(fromIndex = length, toIndex = plaintextLength)
readLength += length
}
return readLength
} catch (e: ShortBufferException) {
throw AssertionError(e)
}
}
@Throws(IOException::class)
private fun readFinal(buffer: ByteArray, offset: Int, length: Int): Int {
try {
val internal = ByteArray(buffer.size)
val actualLength = min(length, cipher.doFinal(internal, 0))
internal.copyInto(destination = buffer, destinationOffset = offset, endIndex = actualLength)
done = true
return actualLength
} catch (e: IllegalBlockSizeException) {
throw IOException(e)
} catch (e: BadPaddingException) {
throw IOException(e)
} catch (e: ShortBufferException) {
throw IOException(e)
}
}
}

View File

@@ -161,32 +161,6 @@ class AttachmentCipherTest {
}
}
@Test(expected = InvalidMessageException::class)
fun attachment_decryptFailOnNullDigest_nonIncremental() {
attachment_decryptFailOnNullDigest(incremental = false)
}
@Test(expected = InvalidMessageException::class)
fun attachment_decryptFailOnNullDigest_incremental() {
attachment_decryptFailOnNullDigest(incremental = true)
}
private fun attachment_decryptFailOnNullDigest(incremental: Boolean) {
var cipherFile: File? = null
try {
val key = Util.getSecretBytes(64)
val plaintextInput = Util.getSecretBytes(MEBIBYTE)
val encryptResult = encryptData(plaintextInput, key, incremental)
cipherFile = writeToFile(encryptResult.ciphertext)
AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, null, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
} finally {
cipherFile?.delete()
}
}
@Test(expected = InvalidMessageException::class)
fun attachment_decryptFailOnBadDigest_nonIncremental() {
attachment_decryptFailOnBadDigest(incremental = false)
@@ -293,7 +267,7 @@ class AttachmentCipherTest {
val encryptResult = encryptData(plaintextInput, key, false)
val cipherFile = writeToFile(encryptResult.ciphertext)
val inputStream = AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, plaintextInput.size.toLong())
val inputStream = AttachmentCipherInputStream.createForArchivedMediaOuterLayer(keyMaterial, cipherFile, plaintextInput.size.toLong())
val plaintextOutput = readInputStreamFully(inputStream)
assertThat(plaintextOutput).isEqualTo(plaintextInput)
@@ -310,7 +284,7 @@ class AttachmentCipherTest {
val encryptResult = encryptData(plaintextInput, key, false)
val cipherFile = writeToFile(encryptResult.ciphertext)
val inputStream: InputStream = AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, plaintextInput.size.toLong())
val inputStream: InputStream = AttachmentCipherInputStream.createForArchivedMediaOuterLayer(keyMaterial, cipherFile, plaintextInput.size.toLong())
val plaintextOutput = readInputStreamFully(inputStream)
assertThat(plaintextOutput).isEqualTo(plaintextInput)
@@ -332,7 +306,7 @@ class AttachmentCipherTest {
val encryptResult = encryptData(plaintextInput, key, false)
cipherFile = writeToFile(encryptResult.ciphertext)
AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, plaintextInput.size.toLong())
AttachmentCipherInputStream.createForArchivedMediaOuterLayer(keyMaterial, cipherFile, plaintextInput.size.toLong())
} catch (e: InvalidMessageException) {
hitCorrectException = true
} finally {
@@ -368,7 +342,7 @@ class AttachmentCipherTest {
val cipherFile = writeToFile(encryptedData)
val keyMaterial = createMediaKeyMaterial(key)
val decryptedStream: InputStream = AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, length.toLong())
val decryptedStream: InputStream = AttachmentCipherInputStream.createForArchivedMediaOuterLayer(keyMaterial, cipherFile, length.toLong())
val plaintextOutput = readInputStreamFully(decryptedStream)
Assert.assertArrayEquals(plaintextInput, plaintextOutput)
@@ -377,6 +351,60 @@ class AttachmentCipherTest {
}
}
@Test
fun archiveInnerAndOuter_encryptDecrypt_nonIncremental() {
archiveInnerAndOuter_encryptDecrypt(incremental = false, fileSize = MEBIBYTE)
}
@Test
fun archiveInnerAndOuter_encryptDecrypt_incremental() {
archiveInnerAndOuter_encryptDecrypt(incremental = true, fileSize = MEBIBYTE)
}
@Test
fun archiveInnerAndOuter_encryptDecrypt_nonIncremental_manyFileSizes() {
for (i in 0..99) {
archiveInnerAndOuter_encryptDecrypt(incremental = false, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024))
}
}
@Test
fun archiveInnerAndOuter_encryptDecrypt_incremental_manyFileSizes() {
// Designed to stress the various boundary conditions of reading the final mac
for (i in 0..99) {
archiveInnerAndOuter_encryptDecrypt(incremental = true, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024))
}
}
private fun archiveInnerAndOuter_encryptDecrypt(incremental: Boolean, fileSize: Int) {
val innerKey = Util.getSecretBytes(64)
val plaintextInput = Util.getSecretBytes(fileSize)
val innerEncryptResult = encryptData(plaintextInput, innerKey, incremental)
val outerKey = Util.getSecretBytes(64)
val outerEncryptResult = encryptData(innerEncryptResult.ciphertext, outerKey, false)
val cipherFile = writeToFile(outerEncryptResult.ciphertext)
val keyMaterial = createMediaKeyMaterial(outerKey)
val decryptedStream: LimitedInputStream = AttachmentCipherInputStream.createForArchivedMediaOuterAndInnerLayers(
archivedMediaKeyMaterial = keyMaterial,
file = cipherFile,
originalCipherTextLength = innerEncryptResult.ciphertext.size.toLong(),
plaintextLength = plaintextInput.size.toLong(),
combinedKeyMaterial = innerKey,
digest = innerEncryptResult.digest,
incrementalDigest = innerEncryptResult.incrementalDigest,
incrementalMacChunkSize = innerEncryptResult.chunkSizeChoice
)
val plaintextOutput = decryptedStream.readFully(autoClose = false)
assertThat(plaintextOutput).isEqualTo(plaintextInput)
assertThat(decryptedStream.leftoverStream().allMatch { it == 0.toByte() }).isTrue()
cipherFile.delete()
}
@Test
fun archive_decryptFailOnBadMac() {
var cipherFile: File? = null
@@ -393,7 +421,7 @@ class AttachmentCipherTest {
cipherFile = writeToFile(badMacCiphertext)
val keyMaterial = createMediaKeyMaterial(key)
AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, plaintextInput.size.toLong())
AttachmentCipherInputStream.createForArchivedMediaOuterLayer(keyMaterial, cipherFile, plaintextInput.size.toLong())
Assert.fail()
} catch (e: InvalidMessageException) {
hitCorrectException = true