mirror of
https://github.com/signalapp/Signal-Android.git
synced 2025-12-23 12:38:33 +00:00
Refactor AttachmentCipherInputStream.
This commit is contained in:
committed by
Michelle Tang
parent
9798f5cc7c
commit
4f6a5de227
@@ -199,7 +199,7 @@ class AttachmentTableTest {
|
||||
val attachmentId = SignalDatabase.attachments.insertAttachmentsForMessage(mmsId, listOf(createAttachmentPointer(key, badlyPaddedDigest, plaintext.size)), emptyList()).values.first()
|
||||
|
||||
// Give data to attachment table
|
||||
val cipherInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintext.size.toLong(), key, badlyPaddedDigest, null, 4, false)
|
||||
val cipherInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintext.size.toLong(), key, badlyPaddedDigest, null, 4)
|
||||
SignalDatabase.attachments.finalizeAttachmentAfterDownload(mmsId, attachmentId, cipherInputStream, iv)
|
||||
|
||||
// Verify the digest has been updated to the properly padded one
|
||||
@@ -230,7 +230,7 @@ class AttachmentTableTest {
|
||||
val attachmentId = SignalDatabase.attachments.insertAttachmentsForMessage(mmsId, listOf(createAttachmentPointer(key, digest, plaintext.size)), emptyList()).values.first()
|
||||
|
||||
// Give data to attachment table
|
||||
val cipherInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintext.size.toLong(), key, digest, null, 4, false)
|
||||
val cipherInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintext.size.toLong(), key, digest, null, 4)
|
||||
SignalDatabase.attachments.finalizeAttachmentAfterDownload(mmsId, attachmentId, cipherInputStream, iv)
|
||||
|
||||
// Verify the digest hasn't changed
|
||||
|
||||
@@ -252,7 +252,6 @@ class RestoreAttachmentJob private constructor(
|
||||
pointer,
|
||||
attachmentFile,
|
||||
maxReceiveSize,
|
||||
false,
|
||||
progressListener
|
||||
)
|
||||
} else {
|
||||
|
||||
@@ -132,14 +132,13 @@ class RestoreAttachmentThumbnailJob private constructor(
|
||||
|
||||
Log.i(TAG, "Downloading thumbnail for $attachmentId")
|
||||
val downloadResult = AppDependencies.signalServiceMessageReceiver
|
||||
.retrieveArchivedAttachment(
|
||||
.retrieveArchivedThumbnail(
|
||||
SignalStore.backup.mediaRootBackupKey.deriveMediaSecrets(attachment.requireThumbnailMediaName()),
|
||||
cdnCredentials,
|
||||
thumbnailTransferFile,
|
||||
pointer,
|
||||
thumbnailFile,
|
||||
maxThumbnailSize,
|
||||
true,
|
||||
progressListener
|
||||
)
|
||||
|
||||
|
||||
@@ -149,7 +149,15 @@ class RestoreLocalAttachmentJob private constructor(
|
||||
try {
|
||||
val iv = ByteArray(16)
|
||||
streamSupplier.openStream().use { StreamUtil.readFully(it, iv) }
|
||||
AttachmentCipherInputStream.createForAttachment(streamSupplier, size, attachment.size, combinedKey, attachment.remoteDigest, null, 0, false).use { input ->
|
||||
AttachmentCipherInputStream.createForAttachment(
|
||||
streamSupplier = streamSupplier,
|
||||
streamLength = size,
|
||||
plaintextLength = attachment.size,
|
||||
combinedKeyMaterial = combinedKey,
|
||||
digest = attachment.remoteDigest,
|
||||
incrementalDigest = null,
|
||||
incrementalMacChunkSize = 0
|
||||
).use { input ->
|
||||
SignalDatabase.attachments.finalizeAttachmentAfterDownload(attachment.mmsId, attachment.attachmentId, input, iv)
|
||||
}
|
||||
} catch (e: InvalidMessageException) {
|
||||
|
||||
@@ -11,6 +11,7 @@ import org.signal.libsignal.protocol.InvalidMessageException;
|
||||
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.Optional;
|
||||
@@ -37,7 +38,12 @@ class AttachmentStreamLocalUriFetcher implements DataFetcher<InputStream> {
|
||||
public void loadData(@NonNull Priority priority, @NonNull DataCallback<? super InputStream> callback) {
|
||||
try {
|
||||
if (!digest.isPresent()) throw new InvalidMessageException("No attachment digest!");
|
||||
is = AttachmentCipherInputStream.createForAttachment(attachment, plaintextLength, key, digest.get(), null, 0);
|
||||
is = AttachmentCipherInputStream.createForAttachment(attachment,
|
||||
plaintextLength,
|
||||
key,
|
||||
digest.get(),
|
||||
null,
|
||||
0);
|
||||
callback.onDataReady(is);
|
||||
} catch (IOException | InvalidMessageException e) {
|
||||
callback.onLoadFailed(e);
|
||||
|
||||
@@ -71,7 +71,8 @@ class PartDataSource implements DataSource {
|
||||
final boolean hasData = attachment.hasData;
|
||||
|
||||
if (inProgress && !hasData && hasIncrementalDigest && attachmentKey != null) {
|
||||
final byte[] decode = Base64.decode(attachmentKey);
|
||||
final byte[] decodedKey = Base64.decode(attachmentKey);
|
||||
|
||||
if (attachment.transferState == AttachmentTable.TRANSFER_RESTORE_IN_PROGRESS && attachment.archiveTransferState == AttachmentTable.ArchiveTransferState.FINISHED) {
|
||||
final File archiveFile = attachmentDatabase.getOrCreateArchiveTransferFile(attachment.attachmentId);
|
||||
try {
|
||||
@@ -81,7 +82,11 @@ class PartDataSource implements DataSource {
|
||||
MediaRootBackupKey.MediaKeyMaterial mediaKeyMaterial = SignalStore.backup().getMediaRootBackupKey().deriveMediaSecretsFromMediaId(mediaId);
|
||||
long originalCipherLength = AttachmentCipherStreamUtil.getCiphertextLength(PaddingInputStream.getPaddedSize(attachment.size));
|
||||
|
||||
this.inputStream = AttachmentCipherInputStream.createStreamingForArchivedAttachment(mediaKeyMaterial, archiveFile, originalCipherLength, attachment.size, attachment.remoteDigest, decode, attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize);
|
||||
if (attachment.remoteDigest == null) {
|
||||
throw new InvalidMessageException("Missing digest!");
|
||||
}
|
||||
|
||||
this.inputStream = AttachmentCipherInputStream.createForArchivedMediaOuterAndInnerLayers(mediaKeyMaterial, archiveFile, originalCipherLength, attachment.size, decodedKey, attachment.remoteDigest, attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize);
|
||||
} catch (InvalidMessageException e) {
|
||||
throw new IOException("Error decrypting attachment stream!", e);
|
||||
}
|
||||
@@ -95,7 +100,7 @@ class PartDataSource implements DataSource {
|
||||
throw new InvalidMessageException("Missing digest!");
|
||||
}
|
||||
|
||||
this.inputStream = AttachmentCipherInputStream.createForAttachment(streamSupplier, streamLength, attachment.size, decode, attachment.remoteDigest, attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize, false);
|
||||
this.inputStream = AttachmentCipherInputStream.createForAttachment(streamSupplier, streamLength, attachment.size, decodedKey, attachment.remoteDigest, attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize);
|
||||
} catch (InvalidMessageException e) {
|
||||
throw new IOException("Error decrypting attachment stream!", e);
|
||||
}
|
||||
|
||||
@@ -7,9 +7,6 @@
|
||||
package org.whispersystems.signalservice.api;
|
||||
|
||||
import org.signal.core.util.StreamUtil;
|
||||
import org.signal.core.util.concurrent.FutureTransformers;
|
||||
import org.signal.core.util.concurrent.ListenableFuture;
|
||||
import org.signal.core.util.concurrent.SettableFuture;
|
||||
import org.signal.core.util.stream.LimitedInputStream;
|
||||
import org.signal.libsignal.protocol.InvalidMessageException;
|
||||
import org.signal.libsignal.zkgroup.profiles.ProfileKey;
|
||||
@@ -18,24 +15,15 @@ import org.whispersystems.signalservice.api.backup.MediaRootBackupKey;
|
||||
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream;
|
||||
import org.whispersystems.signalservice.api.crypto.AttachmentCipherStreamUtil;
|
||||
import org.whispersystems.signalservice.api.crypto.ProfileCipherInputStream;
|
||||
import org.whispersystems.signalservice.api.crypto.SealedSenderAccess;
|
||||
import org.whispersystems.signalservice.api.messages.SignalServiceAttachment.ProgressListener;
|
||||
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer;
|
||||
import org.whispersystems.signalservice.api.messages.SignalServiceDataMessage;
|
||||
import org.whispersystems.signalservice.api.messages.SignalServiceStickerManifest;
|
||||
import org.whispersystems.signalservice.api.profiles.ProfileAndCredential;
|
||||
import org.whispersystems.signalservice.api.profiles.SignalServiceProfile;
|
||||
import org.whispersystems.signalservice.api.push.ServiceId.ACI;
|
||||
import org.whispersystems.signalservice.api.push.SignalServiceAddress;
|
||||
import org.whispersystems.signalservice.api.push.exceptions.MissingConfigurationException;
|
||||
import org.whispersystems.signalservice.internal.ServiceResponse;
|
||||
import org.whispersystems.signalservice.internal.crypto.PaddingInputStream;
|
||||
import org.whispersystems.signalservice.internal.push.IdentityCheckRequest;
|
||||
import org.whispersystems.signalservice.internal.push.IdentityCheckResponse;
|
||||
import org.whispersystems.signalservice.internal.push.PushServiceSocket;
|
||||
import org.whispersystems.signalservice.internal.sticker.Pack;
|
||||
import org.whispersystems.signalservice.internal.util.Util;
|
||||
import org.whispersystems.signalservice.internal.websocket.ResponseMapper;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
@@ -46,15 +34,11 @@ import java.time.ZonedDateTime;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
|
||||
import javax.annotation.Nonnull;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
import io.reactivex.rxjava3.core.Single;
|
||||
|
||||
/**
|
||||
* The primary interface for receiving Signal Service messages.
|
||||
*
|
||||
@@ -118,6 +102,7 @@ public class SignalServiceMessageReceiver {
|
||||
public AttachmentDownloadResult retrieveAttachment(SignalServiceAttachmentPointer pointer, File destination, long maxSizeBytes, ProgressListener listener)
|
||||
throws IOException, InvalidMessageException, MissingConfigurationException {
|
||||
if (!pointer.getDigest().isPresent()) throw new InvalidMessageException("No attachment digest!");
|
||||
if (pointer.getKey() == null) throw new InvalidMessageException("No key!");
|
||||
|
||||
socket.retrieveAttachment(pointer.getCdnNumber(), Collections.emptyMap(), pointer.getRemoteId(), destination, maxSizeBytes, listener);
|
||||
|
||||
@@ -127,7 +112,14 @@ public class SignalServiceMessageReceiver {
|
||||
}
|
||||
|
||||
return new AttachmentDownloadResult(
|
||||
AttachmentCipherInputStream.createForAttachment(destination, pointer.getSize().orElse(0), pointer.getKey(), pointer.getDigest().get(), null, 0),
|
||||
AttachmentCipherInputStream.createForAttachment(
|
||||
destination,
|
||||
pointer.getSize().orElse(0),
|
||||
pointer.getKey(),
|
||||
pointer.getDigest().get(),
|
||||
null,
|
||||
0
|
||||
),
|
||||
iv
|
||||
);
|
||||
}
|
||||
@@ -150,14 +142,17 @@ public class SignalServiceMessageReceiver {
|
||||
@Nonnull SignalServiceAttachmentPointer pointer,
|
||||
@Nonnull File attachmentDestination,
|
||||
long maxSizeBytes,
|
||||
boolean ignoreDigest,
|
||||
@Nullable ProgressListener listener)
|
||||
throws IOException, InvalidMessageException, MissingConfigurationException
|
||||
{
|
||||
if (!ignoreDigest && pointer.getDigest().isEmpty()) {
|
||||
if (pointer.getDigest().isEmpty()) {
|
||||
throw new InvalidMessageException("No attachment digest!");
|
||||
}
|
||||
|
||||
if (pointer.getKey() == null) {
|
||||
throw new InvalidMessageException("No key!");
|
||||
}
|
||||
|
||||
socket.retrieveAttachment(pointer.getCdnNumber(), readCredentialHeaders, pointer.getRemoteId(), archiveDestination, maxSizeBytes, listener);
|
||||
|
||||
long originalCipherLength = pointer.getSize()
|
||||
@@ -166,7 +161,7 @@ public class SignalServiceMessageReceiver {
|
||||
.orElse(0L);
|
||||
|
||||
// There's two layers of encryption -- one from the backup, and one from the attachment. This only strips the outermost backup encryption layer.
|
||||
try (InputStream backupDecrypted = AttachmentCipherInputStream.createForArchivedMedia(archivedMediaKeyMaterial, archiveDestination, originalCipherLength)) {
|
||||
try (InputStream backupDecrypted = AttachmentCipherInputStream.createForArchivedMediaOuterLayer(archivedMediaKeyMaterial, archiveDestination, originalCipherLength)) {
|
||||
try (FileOutputStream fos = new FileOutputStream(attachmentDestination)) {
|
||||
// TODO [backup] I don't think we should be doing the full copy here. This is basically doing the entire download inline in this single line.
|
||||
StreamUtil.copy(backupDecrypted, fos);
|
||||
@@ -182,10 +177,63 @@ public class SignalServiceMessageReceiver {
|
||||
attachmentDestination,
|
||||
pointer.getSize().orElse(0),
|
||||
pointer.getKey(),
|
||||
ignoreDigest ? null : pointer.getDigest().get(),
|
||||
pointer.getDigest().get(),
|
||||
null,
|
||||
0,
|
||||
ignoreDigest
|
||||
0
|
||||
);
|
||||
|
||||
return new AttachmentDownloadResult(dataStream, iv);
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves an archived media attachment.
|
||||
*
|
||||
* @param archivedMediaKeyMaterial Decryption key material for decrypting outer layer of archived media.
|
||||
* @param readCredentialHeaders Headers to pass to the backup CDN to authorize the download
|
||||
* @param archiveDestination The download destination for archived attachment. If this file exists, download will resume.
|
||||
* @param pointer The {@link SignalServiceAttachmentPointer} received in a {@link SignalServiceDataMessage}.
|
||||
* @param attachmentDestination The download destination for this attachment. If this file exists, it is assumed that this is previously-downloaded content that can be resumed.
|
||||
* @param listener An optional listener (may be null) to receive callbacks on download progress.
|
||||
*
|
||||
* @return An InputStream that streams the plaintext attachment contents.
|
||||
*/
|
||||
public AttachmentDownloadResult retrieveArchivedThumbnail(@Nonnull MediaRootBackupKey.MediaKeyMaterial archivedMediaKeyMaterial,
|
||||
@Nonnull Map<String, String> readCredentialHeaders,
|
||||
@Nonnull File archiveDestination,
|
||||
@Nonnull SignalServiceAttachmentPointer pointer,
|
||||
@Nonnull File attachmentDestination,
|
||||
long maxSizeBytes,
|
||||
@Nullable ProgressListener listener)
|
||||
throws IOException, InvalidMessageException, MissingConfigurationException
|
||||
{
|
||||
if (pointer.getKey() == null) {
|
||||
throw new InvalidMessageException("No key!");
|
||||
}
|
||||
|
||||
socket.retrieveAttachment(pointer.getCdnNumber(), readCredentialHeaders, pointer.getRemoteId(), archiveDestination, maxSizeBytes, listener);
|
||||
|
||||
long originalCipherLength = pointer.getSize()
|
||||
.filter(s -> s > 0)
|
||||
.map(s -> AttachmentCipherStreamUtil.getCiphertextLength(PaddingInputStream.getPaddedSize(s)))
|
||||
.orElse(0L);
|
||||
|
||||
// There's two layers of encryption -- one from the backup, and one from the attachment. This only strips the outermost backup encryption layer.
|
||||
try (InputStream backupDecrypted = AttachmentCipherInputStream.createForArchivedMediaOuterLayer(archivedMediaKeyMaterial, archiveDestination, originalCipherLength)) {
|
||||
try (FileOutputStream fos = new FileOutputStream(attachmentDestination)) {
|
||||
// TODO [backup] I don't think we should be doing the full copy here. This is basically doing the entire download inline in this single line.
|
||||
StreamUtil.copy(backupDecrypted, fos);
|
||||
}
|
||||
}
|
||||
|
||||
byte[] iv = new byte[16];
|
||||
try (InputStream tempStream = new FileInputStream(attachmentDestination)) {
|
||||
StreamUtil.readFully(tempStream, iv);
|
||||
}
|
||||
|
||||
LimitedInputStream dataStream = AttachmentCipherInputStream.createForArchiveThumbnailInnerLayer(
|
||||
attachmentDestination,
|
||||
pointer.getSize().orElse(0),
|
||||
pointer.getKey()
|
||||
);
|
||||
|
||||
return new AttachmentDownloadResult(dataStream, iv);
|
||||
|
||||
@@ -5,8 +5,8 @@
|
||||
*/
|
||||
package org.whispersystems.signalservice.api.crypto
|
||||
|
||||
import org.signal.core.util.readNBytesOrThrow
|
||||
import org.signal.core.util.stream.LimitedInputStream
|
||||
import org.signal.core.util.stream.LimitedInputStream.Companion.withoutLimits
|
||||
import org.signal.libsignal.protocol.InvalidMessageException
|
||||
import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice
|
||||
import org.signal.libsignal.protocol.incrementalmac.IncrementalMacInputStream
|
||||
@@ -16,185 +16,316 @@ import org.whispersystems.signalservice.internal.util.Util
|
||||
import java.io.ByteArrayInputStream
|
||||
import java.io.File
|
||||
import java.io.FileInputStream
|
||||
import java.io.FilterInputStream
|
||||
import java.io.IOException
|
||||
import java.io.InputStream
|
||||
import java.security.InvalidKeyException
|
||||
import java.security.MessageDigest
|
||||
import java.security.NoSuchAlgorithmException
|
||||
import javax.annotation.Nonnull
|
||||
import javax.crypto.BadPaddingException
|
||||
import javax.crypto.Cipher
|
||||
import javax.crypto.IllegalBlockSizeException
|
||||
import javax.crypto.Mac
|
||||
import javax.crypto.ShortBufferException
|
||||
import javax.crypto.spec.IvParameterSpec
|
||||
import javax.crypto.spec.SecretKeySpec
|
||||
import kotlin.math.min
|
||||
|
||||
/**
|
||||
* Class for streaming an encrypted push attachment off disk.
|
||||
* Decrypts an attachment stream that has been encrypted with AES/CBC/PKCS5Padding.
|
||||
*
|
||||
* @author Moxie Marlinspike
|
||||
* It assumes that the first 16 bytes of the stream are the IV, and that the rest of the stream is encrypted data.
|
||||
*/
|
||||
class AttachmentCipherInputStream private constructor(
|
||||
inputStream: InputStream,
|
||||
aesKey: ByteArray,
|
||||
private val totalDataSize: Long
|
||||
) : FilterInputStream(inputStream) {
|
||||
object AttachmentCipherInputStream {
|
||||
|
||||
private val cipher: Cipher
|
||||
private const val BLOCK_SIZE = 16
|
||||
private const val CIPHER_KEY_SIZE = 32
|
||||
private const val MAC_KEY_SIZE = 32
|
||||
|
||||
private var done = false
|
||||
private var totalRead: Long = 0
|
||||
private var overflowBuffer: ByteArray? = null
|
||||
|
||||
init {
|
||||
val iv = ByteArray(BLOCK_SIZE)
|
||||
readFullyWithoutDecrypting(iv)
|
||||
|
||||
this.cipher = Cipher.getInstance("AES/CBC/PKCS5Padding")
|
||||
cipher.init(Cipher.DECRYPT_MODE, SecretKeySpec(aesKey, "AES"), IvParameterSpec(iv))
|
||||
/**
|
||||
* Creates a stream to decrypt a typical attachment via a [File].
|
||||
*
|
||||
* @param incrementalDigest If null, incremental mac validation is disabled.
|
||||
* @param incrementalMacChunkSize If 0, incremental mac validation is disabled.
|
||||
*/
|
||||
@JvmStatic
|
||||
@Throws(InvalidMessageException::class, IOException::class)
|
||||
fun createForAttachment(
|
||||
file: File,
|
||||
plaintextLength: Long,
|
||||
combinedKeyMaterial: ByteArray,
|
||||
digest: ByteArray,
|
||||
incrementalDigest: ByteArray?,
|
||||
incrementalMacChunkSize: Int
|
||||
): LimitedInputStream {
|
||||
return create(
|
||||
streamSupplier = { FileInputStream(file) },
|
||||
streamLength = file.length(),
|
||||
plaintextLength = plaintextLength,
|
||||
combinedKeyMaterial = combinedKeyMaterial,
|
||||
digest = digest,
|
||||
incrementalDigest = incrementalDigest,
|
||||
incrementalMacChunkSize = incrementalMacChunkSize,
|
||||
ignoreDigest = false
|
||||
)
|
||||
}
|
||||
|
||||
@Throws(IOException::class)
|
||||
override fun read(): Int {
|
||||
val buffer = ByteArray(1)
|
||||
var read: Int = read(buffer)
|
||||
while (read == 0) {
|
||||
read = read(buffer)
|
||||
/**
|
||||
* Creates a stream to decrypt a typical attachment via a [StreamSupplier].
|
||||
*
|
||||
* @param incrementalDigest If null, incremental mac validation is disabled.
|
||||
* @param incrementalMacChunkSize If 0, incremental mac validation is disabled.
|
||||
*/
|
||||
@JvmStatic
|
||||
@Throws(InvalidMessageException::class, IOException::class)
|
||||
fun createForAttachment(
|
||||
streamSupplier: StreamSupplier,
|
||||
streamLength: Long,
|
||||
plaintextLength: Long,
|
||||
combinedKeyMaterial: ByteArray,
|
||||
digest: ByteArray,
|
||||
incrementalDigest: ByteArray?,
|
||||
incrementalMacChunkSize: Int
|
||||
): LimitedInputStream {
|
||||
return create(
|
||||
streamSupplier = streamSupplier,
|
||||
streamLength = streamLength,
|
||||
plaintextLength = plaintextLength,
|
||||
combinedKeyMaterial = combinedKeyMaterial,
|
||||
digest = digest,
|
||||
incrementalDigest = incrementalDigest,
|
||||
incrementalMacChunkSize = incrementalMacChunkSize,
|
||||
ignoreDigest = false
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* After removing the server layer of encryption using [createForArchivedMediaOuterLayer], use this to decrypt the inner layer of the attachment.
|
||||
* Thumbnails have a special path because we don't do any additional digest/hash validation on them.
|
||||
*/
|
||||
@JvmStatic
|
||||
@Throws(InvalidMessageException::class, IOException::class)
|
||||
fun createForArchiveThumbnailInnerLayer(
|
||||
file: File,
|
||||
plaintextLength: Long,
|
||||
combinedKeyMaterial: ByteArray
|
||||
): LimitedInputStream {
|
||||
return create(
|
||||
streamSupplier = { FileInputStream(file) },
|
||||
streamLength = file.length(),
|
||||
plaintextLength = plaintextLength,
|
||||
combinedKeyMaterial = combinedKeyMaterial,
|
||||
digest = null,
|
||||
incrementalDigest = null,
|
||||
incrementalMacChunkSize = 0,
|
||||
ignoreDigest = true
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* When you archive an attachment, you give the server an encrypted attachment, and the server wraps it in *another* layer of encryption.
|
||||
* This will return a stream that unwraps the server's layer of encryption, giving you a stream that contains a "normally-encrypted" attachment.
|
||||
*
|
||||
* Because we're validating the encryptedDigest/plaintextHash of the inner layer, there's no additional out-of-band validation of this outer layer.
|
||||
*/
|
||||
@JvmStatic
|
||||
@Throws(InvalidMessageException::class, IOException::class)
|
||||
fun createForArchivedMediaOuterLayer(archivedMediaKeyMaterial: MediaKeyMaterial, file: File, originalCipherTextLength: Long): LimitedInputStream {
|
||||
val mac = initMac(archivedMediaKeyMaterial.macKey)
|
||||
|
||||
if (file.length() <= BLOCK_SIZE + mac.macLength) {
|
||||
throw InvalidMessageException("Message shorter than crypto overhead!")
|
||||
}
|
||||
|
||||
if (read == -1) {
|
||||
return read
|
||||
FileInputStream(file).use { macVerificationStream ->
|
||||
verifyMac(macVerificationStream, file.length(), mac, null)
|
||||
}
|
||||
|
||||
return buffer[0].toInt() and 0xFF
|
||||
val encryptedStream = FileInputStream(file)
|
||||
val encryptedStreamExcludingMac = LimitedInputStream(encryptedStream, file.length() - mac.macLength)
|
||||
val cipher = createCipher(encryptedStreamExcludingMac, archivedMediaKeyMaterial.aesKey)
|
||||
val inputStream: InputStream = BetterCipherInputStream(encryptedStreamExcludingMac, cipher)
|
||||
|
||||
return LimitedInputStream(inputStream, originalCipherTextLength)
|
||||
}
|
||||
|
||||
@Throws(IOException::class)
|
||||
override fun read(@Nonnull buffer: ByteArray): Int {
|
||||
return read(buffer, 0, buffer.size)
|
||||
}
|
||||
/**
|
||||
* When you archive an attachment, you give the server an encrypted attachment, and the server wraps it in *another* layer of encryption.
|
||||
*
|
||||
* This creates a stream decrypt both the inner and outer layers of an archived attachment at the same time by basically double-decrypting it.
|
||||
*
|
||||
* @param incrementalDigest If null, incremental mac validation is disabled.
|
||||
* @param incrementalMacChunkSize If 0, incremental mac validation is disabled.
|
||||
*/
|
||||
@JvmStatic
|
||||
@Throws(InvalidMessageException::class, IOException::class)
|
||||
fun createForArchivedMediaOuterAndInnerLayers(
|
||||
archivedMediaKeyMaterial: MediaKeyMaterial,
|
||||
file: File,
|
||||
originalCipherTextLength: Long,
|
||||
plaintextLength: Long,
|
||||
combinedKeyMaterial: ByteArray,
|
||||
digest: ByteArray,
|
||||
incrementalDigest: ByteArray?,
|
||||
incrementalMacChunkSize: Int
|
||||
): LimitedInputStream {
|
||||
val keyMaterial = CombinedKeyMaterial.from(combinedKeyMaterial)
|
||||
val mac = initMac(keyMaterial.macKey)
|
||||
|
||||
@Throws(IOException::class)
|
||||
override fun read(@Nonnull buffer: ByteArray, offset: Int, length: Int): Int {
|
||||
return if (totalRead != totalDataSize) {
|
||||
readIncremental(buffer, offset, length)
|
||||
} else if (!done) {
|
||||
readFinal(buffer, offset, length)
|
||||
} else {
|
||||
-1
|
||||
}
|
||||
}
|
||||
|
||||
override fun markSupported(): Boolean = false
|
||||
|
||||
@Throws(IOException::class)
|
||||
override fun skip(byteCount: Long): Long {
|
||||
var skipped = 0L
|
||||
while (skipped < byteCount) {
|
||||
val remaining = byteCount - skipped
|
||||
val buffer = ByteArray(min(4096, remaining.toInt()))
|
||||
val read = read(buffer)
|
||||
|
||||
skipped += read.toLong()
|
||||
if (originalCipherTextLength <= BLOCK_SIZE + mac.macLength) {
|
||||
throw InvalidMessageException("Message shorter than crypto overhead!")
|
||||
}
|
||||
|
||||
return skipped
|
||||
return create(
|
||||
streamSupplier = { createForArchivedMediaOuterLayer(archivedMediaKeyMaterial, file, originalCipherTextLength) },
|
||||
streamLength = originalCipherTextLength,
|
||||
plaintextLength = plaintextLength,
|
||||
combinedKeyMaterial = combinedKeyMaterial,
|
||||
digest = digest,
|
||||
incrementalDigest = incrementalDigest,
|
||||
incrementalMacChunkSize = incrementalMacChunkSize,
|
||||
ignoreDigest = false
|
||||
)
|
||||
}
|
||||
|
||||
@Throws(IOException::class)
|
||||
private fun readIncremental(outputBuffer: ByteArray, originalOffset: Int, originalLength: Int): Int {
|
||||
var offset = originalOffset
|
||||
var length = originalLength
|
||||
var readLength = 0
|
||||
/**
|
||||
* Creates a stream to decrypt sticker data. Stickers have a special path because the key material is derived from the pack key.
|
||||
*/
|
||||
@JvmStatic
|
||||
@Throws(InvalidMessageException::class, IOException::class)
|
||||
fun createForStickerData(data: ByteArray, packKey: ByteArray): InputStream {
|
||||
val keyMaterial = CombinedKeyMaterial.from(HKDF.deriveSecrets(packKey, "Sticker Pack".toByteArray(), 64))
|
||||
val mac = initMac(keyMaterial.macKey)
|
||||
|
||||
overflowBuffer?.let { overflow ->
|
||||
if (overflow.size > length) {
|
||||
overflow.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = length)
|
||||
overflowBuffer = overflow.copyOfRange(fromIndex = length, toIndex = overflow.size)
|
||||
return length
|
||||
} else if (overflow.size == length) {
|
||||
overflow.copyInto(destination = outputBuffer, destinationOffset = offset)
|
||||
overflowBuffer = null
|
||||
return length
|
||||
} else {
|
||||
overflow.copyInto(destination = outputBuffer, destinationOffset = offset)
|
||||
readLength += overflow.size
|
||||
offset += readLength
|
||||
length -= readLength
|
||||
overflowBuffer = null
|
||||
if (data.size <= BLOCK_SIZE + mac.macLength) {
|
||||
throw InvalidMessageException("Message shorter than crypto overhead!")
|
||||
}
|
||||
|
||||
ByteArrayInputStream(data).use { inputStream ->
|
||||
verifyMac(inputStream, data.size.toLong(), mac, null)
|
||||
}
|
||||
|
||||
val encryptedStream = ByteArrayInputStream(data)
|
||||
val encryptedStreamExcludingMac = LimitedInputStream(encryptedStream, data.size.toLong() - mac.macLength)
|
||||
val cipher = createCipher(encryptedStreamExcludingMac, keyMaterial.aesKey)
|
||||
|
||||
return BetterCipherInputStream(encryptedStreamExcludingMac, cipher)
|
||||
}
|
||||
|
||||
@JvmStatic
|
||||
@Throws(InvalidMessageException::class, IOException::class)
|
||||
private fun create(
|
||||
streamSupplier: StreamSupplier,
|
||||
streamLength: Long,
|
||||
plaintextLength: Long,
|
||||
combinedKeyMaterial: ByteArray,
|
||||
digest: ByteArray?,
|
||||
incrementalDigest: ByteArray?,
|
||||
incrementalMacChunkSize: Int,
|
||||
ignoreDigest: Boolean
|
||||
): LimitedInputStream {
|
||||
val keyMaterial = CombinedKeyMaterial.from(combinedKeyMaterial)
|
||||
val mac = initMac(keyMaterial.macKey)
|
||||
|
||||
if (streamLength <= BLOCK_SIZE + mac.macLength) {
|
||||
throw InvalidMessageException("Message shorter than crypto overhead! length: $streamLength")
|
||||
}
|
||||
|
||||
if (!ignoreDigest && digest == null) {
|
||||
throw InvalidMessageException("Missing digest!")
|
||||
}
|
||||
|
||||
val wrappedStream: InputStream
|
||||
val hasIncrementalMac = incrementalDigest != null && incrementalDigest.isNotEmpty() && incrementalMacChunkSize > 0
|
||||
|
||||
if (!hasIncrementalMac) {
|
||||
streamSupplier.openStream().use { macVerificationStream ->
|
||||
verifyMac(macVerificationStream, streamLength, mac, digest)
|
||||
}
|
||||
}
|
||||
|
||||
if (length + totalRead > totalDataSize) {
|
||||
length = (totalDataSize - totalRead).toInt()
|
||||
}
|
||||
|
||||
val ciphertextBuffer = ByteArray(length)
|
||||
val ciphertextReadLength = if (ciphertextBuffer.size <= cipher.blockSize) {
|
||||
ciphertextBuffer.size
|
||||
wrappedStream = streamSupplier.openStream()
|
||||
} else {
|
||||
// Ensure we leave the final block for readFinal()
|
||||
ciphertextBuffer.size - cipher.blockSize
|
||||
}
|
||||
val ciphertextRead = super.read(ciphertextBuffer, 0, ciphertextReadLength)
|
||||
totalRead += ciphertextRead.toLong()
|
||||
if (digest == null) {
|
||||
throw InvalidMessageException("Missing digest for incremental mac validation!")
|
||||
}
|
||||
|
||||
wrappedStream = IncrementalMacInputStream(
|
||||
IncrementalMacAdditionalValidationsInputStream(
|
||||
wrapped = streamSupplier.openStream(),
|
||||
fileLength = streamLength,
|
||||
mac = mac,
|
||||
theirDigest = digest
|
||||
),
|
||||
keyMaterial.macKey,
|
||||
ChunkSizeChoice.everyNthByte(incrementalMacChunkSize),
|
||||
incrementalDigest
|
||||
)
|
||||
}
|
||||
|
||||
val encryptedStreamExcludingMac = LimitedInputStream(wrappedStream, streamLength - mac.macLength)
|
||||
val cipher = createCipher(encryptedStreamExcludingMac, keyMaterial.aesKey)
|
||||
val decryptingStream: InputStream = BetterCipherInputStream(encryptedStreamExcludingMac, cipher)
|
||||
|
||||
return LimitedInputStream(decryptingStream, plaintextLength)
|
||||
}
|
||||
|
||||
private fun createCipher(inputStream: InputStream, aesKey: ByteArray): Cipher {
|
||||
val iv = inputStream.readNBytesOrThrow(BLOCK_SIZE)
|
||||
|
||||
return Cipher.getInstance("AES/CBC/PKCS5Padding").apply {
|
||||
init(Cipher.DECRYPT_MODE, SecretKeySpec(aesKey, "AES"), IvParameterSpec(iv))
|
||||
}
|
||||
}
|
||||
|
||||
private fun initMac(key: ByteArray): Mac {
|
||||
try {
|
||||
var plaintextLength = cipher.getOutputSize(ciphertextRead)
|
||||
|
||||
if (plaintextLength <= length) {
|
||||
readLength += cipher.update(ciphertextBuffer, 0, ciphertextRead, outputBuffer, offset)
|
||||
return readLength
|
||||
}
|
||||
|
||||
val plaintextBuffer = ByteArray(plaintextLength)
|
||||
plaintextLength = cipher.update(ciphertextBuffer, 0, ciphertextRead, plaintextBuffer, 0)
|
||||
if (plaintextLength <= length) {
|
||||
plaintextBuffer.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = plaintextLength)
|
||||
readLength += plaintextLength
|
||||
} else {
|
||||
plaintextBuffer.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = length)
|
||||
overflowBuffer = plaintextBuffer.copyOfRange(fromIndex = length, toIndex = plaintextLength)
|
||||
readLength += length
|
||||
}
|
||||
return readLength
|
||||
} catch (e: ShortBufferException) {
|
||||
val mac = Mac.getInstance("HmacSHA256")
|
||||
mac.init(SecretKeySpec(key, "HmacSHA256"))
|
||||
return mac
|
||||
} catch (e: NoSuchAlgorithmException) {
|
||||
throw AssertionError(e)
|
||||
} catch (e: InvalidKeyException) {
|
||||
throw AssertionError(e)
|
||||
}
|
||||
}
|
||||
|
||||
@Throws(IOException::class)
|
||||
private fun readFinal(buffer: ByteArray, offset: Int, length: Int): Int {
|
||||
@Throws(InvalidMessageException::class)
|
||||
private fun verifyMac(@Nonnull inputStream: InputStream, length: Long, @Nonnull mac: Mac, theirDigest: ByteArray?) {
|
||||
try {
|
||||
val internal = ByteArray(buffer.size)
|
||||
val actualLength = min(length, cipher.doFinal(internal, 0))
|
||||
internal.copyInto(destination = buffer, destinationOffset = offset, endIndex = actualLength)
|
||||
val digest = MessageDigest.getInstance("SHA256")
|
||||
var remainingData = Util.toIntExact(length) - mac.macLength
|
||||
val buffer = ByteArray(4096)
|
||||
|
||||
done = true
|
||||
return actualLength
|
||||
} catch (e: IllegalBlockSizeException) {
|
||||
throw IOException(e)
|
||||
} catch (e: BadPaddingException) {
|
||||
throw IOException(e)
|
||||
} catch (e: ShortBufferException) {
|
||||
throw IOException(e)
|
||||
while (remainingData > 0) {
|
||||
val read = inputStream.read(buffer, 0, min(buffer.size, remainingData))
|
||||
mac.update(buffer, 0, read)
|
||||
digest.update(buffer, 0, read)
|
||||
remainingData -= read
|
||||
}
|
||||
|
||||
val ourMac = mac.doFinal()
|
||||
val theirMac = ByteArray(mac.macLength)
|
||||
Util.readFully(inputStream, theirMac)
|
||||
|
||||
if (!MessageDigest.isEqual(ourMac, theirMac)) {
|
||||
throw InvalidMessageException("MAC doesn't match!")
|
||||
}
|
||||
|
||||
val ourDigest = digest.digest(theirMac)
|
||||
|
||||
if (theirDigest != null && !MessageDigest.isEqual(ourDigest, theirDigest)) {
|
||||
throw InvalidMessageException("Digest doesn't match!")
|
||||
}
|
||||
} catch (e: IOException) {
|
||||
throw InvalidMessageException(e)
|
||||
} catch (e: ArithmeticException) {
|
||||
throw InvalidMessageException(e)
|
||||
} catch (e: NoSuchAlgorithmException) {
|
||||
throw AssertionError(e)
|
||||
}
|
||||
}
|
||||
|
||||
@Throws(IOException::class)
|
||||
private fun readFullyWithoutDecrypting(buffer: ByteArray) {
|
||||
var offset = 0
|
||||
|
||||
while (true) {
|
||||
val read = super.read(buffer, offset, buffer.size - offset)
|
||||
|
||||
if (read + offset < buffer.size) {
|
||||
offset += read
|
||||
} else {
|
||||
return
|
||||
private class CombinedKeyMaterial(val aesKey: ByteArray, val macKey: ByteArray) {
|
||||
companion object {
|
||||
fun from(combinedKeyMaterial: ByteArray): CombinedKeyMaterial {
|
||||
val parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE)
|
||||
return CombinedKeyMaterial(parts[0], parts[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -204,207 +335,4 @@ class AttachmentCipherInputStream private constructor(
|
||||
@Throws(IOException::class)
|
||||
fun openStream(): InputStream
|
||||
}
|
||||
|
||||
companion object {
|
||||
private const val BLOCK_SIZE = 16
|
||||
private const val CIPHER_KEY_SIZE = 32
|
||||
private const val MAC_KEY_SIZE = 32
|
||||
|
||||
/**
|
||||
* Passing in a null incrementalDigest and/or 0 for the chunk size at the call site disables incremental mac validation.
|
||||
*
|
||||
* Passing in true for ignoreDigest DOES NOT VERIFY THE DIGEST
|
||||
*/
|
||||
@JvmStatic
|
||||
@JvmOverloads
|
||||
@Throws(InvalidMessageException::class, IOException::class)
|
||||
fun createForAttachment(file: File, plaintextLength: Long, combinedKeyMaterial: ByteArray?, digest: ByteArray?, incrementalDigest: ByteArray?, incrementalMacChunkSize: Int, ignoreDigest: Boolean = false): LimitedInputStream {
|
||||
return createForAttachment({ FileInputStream(file) }, file.length(), plaintextLength, combinedKeyMaterial, digest, incrementalDigest, incrementalMacChunkSize, ignoreDigest)
|
||||
}
|
||||
|
||||
/**
|
||||
* Passing in a null incrementalDigest and/or 0 for the chunk size at the call site disables incremental mac validation.
|
||||
*
|
||||
* Passing in true for ignoreDigest DOES NOT VERIFY THE DIGEST
|
||||
*/
|
||||
@JvmStatic
|
||||
@Throws(InvalidMessageException::class, IOException::class)
|
||||
fun createForAttachment(
|
||||
streamSupplier: StreamSupplier,
|
||||
streamLength: Long,
|
||||
plaintextLength: Long,
|
||||
combinedKeyMaterial: ByteArray?,
|
||||
digest: ByteArray?,
|
||||
incrementalDigest: ByteArray?,
|
||||
incrementalMacChunkSize: Int,
|
||||
ignoreDigest: Boolean
|
||||
): LimitedInputStream {
|
||||
val parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE)
|
||||
val mac = initMac(parts[1])
|
||||
|
||||
if (streamLength <= BLOCK_SIZE + mac.macLength) {
|
||||
throw InvalidMessageException("Message shorter than crypto overhead! length: $streamLength")
|
||||
}
|
||||
|
||||
if (!ignoreDigest && digest == null) {
|
||||
throw InvalidMessageException("Missing digest!")
|
||||
}
|
||||
|
||||
val wrappedStream: InputStream
|
||||
val hasIncrementalMac = incrementalDigest != null && incrementalDigest.isNotEmpty() && incrementalMacChunkSize > 0
|
||||
|
||||
if (!hasIncrementalMac) {
|
||||
streamSupplier.openStream().use { macVerificationStream ->
|
||||
verifyMac(macVerificationStream, streamLength, mac, digest)
|
||||
}
|
||||
wrappedStream = streamSupplier.openStream()
|
||||
} else {
|
||||
wrappedStream = IncrementalMacInputStream(
|
||||
IncrementalMacAdditionalValidationsInputStream(
|
||||
streamSupplier.openStream(),
|
||||
streamLength,
|
||||
mac,
|
||||
digest!!
|
||||
),
|
||||
parts[1],
|
||||
ChunkSizeChoice.everyNthByte(incrementalMacChunkSize),
|
||||
incrementalDigest
|
||||
)
|
||||
}
|
||||
val inputStream: InputStream = AttachmentCipherInputStream(wrappedStream, parts[0], streamLength - BLOCK_SIZE - mac.macLength)
|
||||
|
||||
return LimitedInputStream(inputStream, plaintextLength)
|
||||
}
|
||||
|
||||
/**
|
||||
* Decrypt archived media to it's original attachment encrypted blob.
|
||||
*/
|
||||
@JvmStatic
|
||||
@Throws(InvalidMessageException::class, IOException::class)
|
||||
fun createForArchivedMedia(archivedMediaKeyMaterial: MediaKeyMaterial, file: File, originalCipherTextLength: Long): LimitedInputStream {
|
||||
val mac = initMac(archivedMediaKeyMaterial.macKey)
|
||||
|
||||
if (file.length() <= BLOCK_SIZE + mac.macLength) {
|
||||
throw InvalidMessageException("Message shorter than crypto overhead!")
|
||||
}
|
||||
|
||||
FileInputStream(file).use { macVerificationStream ->
|
||||
verifyMac(macVerificationStream, file.length(), mac, null)
|
||||
}
|
||||
val inputStream: InputStream = AttachmentCipherInputStream(FileInputStream(file), archivedMediaKeyMaterial.aesKey, file.length() - BLOCK_SIZE - mac.macLength)
|
||||
|
||||
return LimitedInputStream(inputStream, originalCipherTextLength)
|
||||
}
|
||||
|
||||
@JvmStatic
|
||||
@Throws(InvalidMessageException::class, IOException::class)
|
||||
fun createStreamingForArchivedAttachment(
|
||||
archivedMediaKeyMaterial: MediaKeyMaterial,
|
||||
file: File,
|
||||
originalCipherTextLength: Long,
|
||||
plaintextLength: Long,
|
||||
combinedKeyMaterial: ByteArray?,
|
||||
digest: ByteArray,
|
||||
incrementalDigest: ByteArray?,
|
||||
incrementalMacChunkSize: Int
|
||||
): LimitedInputStream {
|
||||
val archiveStream: InputStream = createForArchivedMedia(archivedMediaKeyMaterial, file, originalCipherTextLength)
|
||||
|
||||
val parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE)
|
||||
val mac = initMac(parts[1])
|
||||
|
||||
if (originalCipherTextLength <= BLOCK_SIZE + mac.macLength) {
|
||||
throw InvalidMessageException("Message shorter than crypto overhead!")
|
||||
}
|
||||
|
||||
val wrappedStream: InputStream = IncrementalMacInputStream(
|
||||
IncrementalMacAdditionalValidationsInputStream(
|
||||
wrapped = archiveStream,
|
||||
fileLength = file.length(),
|
||||
mac = mac,
|
||||
theirDigest = digest
|
||||
),
|
||||
parts[1],
|
||||
ChunkSizeChoice.everyNthByte(incrementalMacChunkSize),
|
||||
incrementalDigest
|
||||
)
|
||||
|
||||
val inputStream: InputStream = AttachmentCipherInputStream(
|
||||
inputStream = wrappedStream,
|
||||
aesKey = parts[0],
|
||||
totalDataSize = file.length() - BLOCK_SIZE - mac.macLength
|
||||
)
|
||||
|
||||
return if (plaintextLength != 0L) {
|
||||
LimitedInputStream(inputStream, plaintextLength)
|
||||
} else {
|
||||
withoutLimits(inputStream)
|
||||
}
|
||||
}
|
||||
|
||||
@JvmStatic
|
||||
@Throws(InvalidMessageException::class, IOException::class)
|
||||
fun createForStickerData(data: ByteArray, packKey: ByteArray?): InputStream {
|
||||
val combinedKeyMaterial = HKDF.deriveSecrets(packKey, "Sticker Pack".toByteArray(), 64)
|
||||
val parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE)
|
||||
val mac = initMac(parts[1])
|
||||
|
||||
if (data.size <= BLOCK_SIZE + mac.macLength) {
|
||||
throw InvalidMessageException("Message shorter than crypto overhead!")
|
||||
}
|
||||
|
||||
ByteArrayInputStream(data).use { inputStream ->
|
||||
verifyMac(inputStream, data.size.toLong(), mac, null)
|
||||
}
|
||||
return AttachmentCipherInputStream(ByteArrayInputStream(data), parts[0], (data.size - BLOCK_SIZE - mac.macLength).toLong())
|
||||
}
|
||||
|
||||
private fun initMac(key: ByteArray): Mac {
|
||||
try {
|
||||
val mac = Mac.getInstance("HmacSHA256")
|
||||
mac.init(SecretKeySpec(key, "HmacSHA256"))
|
||||
return mac
|
||||
} catch (e: NoSuchAlgorithmException) {
|
||||
throw AssertionError(e)
|
||||
} catch (e: InvalidKeyException) {
|
||||
throw AssertionError(e)
|
||||
}
|
||||
}
|
||||
|
||||
@Throws(InvalidMessageException::class)
|
||||
private fun verifyMac(@Nonnull inputStream: InputStream, length: Long, @Nonnull mac: Mac, theirDigest: ByteArray?) {
|
||||
try {
|
||||
val digest = MessageDigest.getInstance("SHA256")
|
||||
var remainingData = Util.toIntExact(length) - mac.macLength
|
||||
val buffer = ByteArray(4096)
|
||||
|
||||
while (remainingData > 0) {
|
||||
val read = inputStream.read(buffer, 0, min(buffer.size, remainingData))
|
||||
mac.update(buffer, 0, read)
|
||||
digest.update(buffer, 0, read)
|
||||
remainingData -= read
|
||||
}
|
||||
|
||||
val ourMac = mac.doFinal()
|
||||
val theirMac = ByteArray(mac.macLength)
|
||||
Util.readFully(inputStream, theirMac)
|
||||
|
||||
if (!MessageDigest.isEqual(ourMac, theirMac)) {
|
||||
throw InvalidMessageException("MAC doesn't match!")
|
||||
}
|
||||
|
||||
val ourDigest = digest.digest(theirMac)
|
||||
|
||||
if (theirDigest != null && !MessageDigest.isEqual(ourDigest, theirDigest)) {
|
||||
throw InvalidMessageException("Digest doesn't match!")
|
||||
}
|
||||
} catch (e: IOException) {
|
||||
throw InvalidMessageException(e)
|
||||
} catch (e: ArithmeticException) {
|
||||
throw InvalidMessageException(e)
|
||||
} catch (e: NoSuchAlgorithmException) {
|
||||
throw AssertionError(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,148 @@
|
||||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.signalservice.api.crypto
|
||||
|
||||
import java.io.FilterInputStream
|
||||
import java.io.IOException
|
||||
import java.io.InputStream
|
||||
import javax.annotation.Nonnull
|
||||
import javax.crypto.BadPaddingException
|
||||
import javax.crypto.Cipher
|
||||
import javax.crypto.IllegalBlockSizeException
|
||||
import javax.crypto.ShortBufferException
|
||||
import kotlin.math.min
|
||||
|
||||
/**
|
||||
* This is similar to [javax.crypto.CipherInputStream], but it fixes various issues, including proper error propagation,
|
||||
* and proper handling of boundary conditions.
|
||||
*/
|
||||
class BetterCipherInputStream(
|
||||
inputStream: InputStream,
|
||||
val cipher: Cipher
|
||||
) : FilterInputStream(inputStream) {
|
||||
|
||||
private var done = false
|
||||
private var overflowBuffer: ByteArray? = null
|
||||
|
||||
@Throws(IOException::class)
|
||||
override fun read(): Int {
|
||||
val buffer = ByteArray(1)
|
||||
var read: Int = read(buffer)
|
||||
while (read == 0) {
|
||||
read = read(buffer)
|
||||
}
|
||||
|
||||
if (read == -1) {
|
||||
return read
|
||||
}
|
||||
|
||||
return buffer[0].toInt() and 0xFF
|
||||
}
|
||||
|
||||
@Throws(IOException::class)
|
||||
override fun read(@Nonnull buffer: ByteArray): Int {
|
||||
return read(buffer, 0, buffer.size)
|
||||
}
|
||||
|
||||
@Throws(IOException::class)
|
||||
override fun read(@Nonnull buffer: ByteArray, offset: Int, length: Int): Int {
|
||||
return if (!done) {
|
||||
readIncremental(buffer, offset, length)
|
||||
} else {
|
||||
-1
|
||||
}
|
||||
}
|
||||
|
||||
override fun markSupported(): Boolean = false
|
||||
|
||||
@Throws(IOException::class)
|
||||
override fun skip(byteCount: Long): Long {
|
||||
val buffer = ByteArray(4096)
|
||||
var skipped = 0L
|
||||
|
||||
while (skipped < byteCount) {
|
||||
val remaining = byteCount - skipped
|
||||
val read = read(buffer, 0, remaining.toInt())
|
||||
|
||||
skipped += read.toLong()
|
||||
}
|
||||
|
||||
return skipped
|
||||
}
|
||||
|
||||
@Throws(IOException::class)
|
||||
private fun readIncremental(outputBuffer: ByteArray, originalOffset: Int, originalLength: Int): Int {
|
||||
var offset = originalOffset
|
||||
var length = originalLength
|
||||
var readLength = 0
|
||||
|
||||
overflowBuffer?.let { overflow ->
|
||||
if (overflow.size > length) {
|
||||
overflow.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = length)
|
||||
overflowBuffer = overflow.copyOfRange(fromIndex = length, toIndex = overflow.size)
|
||||
return length
|
||||
} else if (overflow.size == length) {
|
||||
overflow.copyInto(destination = outputBuffer, destinationOffset = offset)
|
||||
overflowBuffer = null
|
||||
return length
|
||||
} else {
|
||||
overflow.copyInto(destination = outputBuffer, destinationOffset = offset)
|
||||
readLength += overflow.size
|
||||
offset += readLength
|
||||
length -= readLength
|
||||
overflowBuffer = null
|
||||
}
|
||||
}
|
||||
|
||||
val ciphertextBuffer = ByteArray(length)
|
||||
val ciphertextRead = super.read(ciphertextBuffer, 0, ciphertextBuffer.size)
|
||||
|
||||
if (ciphertextRead == -1) {
|
||||
return readFinal(outputBuffer, offset, length)
|
||||
}
|
||||
|
||||
try {
|
||||
var plaintextLength = cipher.getOutputSize(ciphertextRead)
|
||||
|
||||
if (plaintextLength <= length) {
|
||||
readLength += cipher.update(ciphertextBuffer, 0, ciphertextRead, outputBuffer, offset)
|
||||
return readLength
|
||||
}
|
||||
|
||||
val plaintextBuffer = ByteArray(plaintextLength)
|
||||
plaintextLength = cipher.update(ciphertextBuffer, 0, ciphertextRead, plaintextBuffer, 0)
|
||||
if (plaintextLength <= length) {
|
||||
plaintextBuffer.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = plaintextLength)
|
||||
readLength += plaintextLength
|
||||
} else {
|
||||
plaintextBuffer.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = length)
|
||||
overflowBuffer = plaintextBuffer.copyOfRange(fromIndex = length, toIndex = plaintextLength)
|
||||
readLength += length
|
||||
}
|
||||
return readLength
|
||||
} catch (e: ShortBufferException) {
|
||||
throw AssertionError(e)
|
||||
}
|
||||
}
|
||||
|
||||
@Throws(IOException::class)
|
||||
private fun readFinal(buffer: ByteArray, offset: Int, length: Int): Int {
|
||||
try {
|
||||
val internal = ByteArray(buffer.size)
|
||||
val actualLength = min(length, cipher.doFinal(internal, 0))
|
||||
internal.copyInto(destination = buffer, destinationOffset = offset, endIndex = actualLength)
|
||||
|
||||
done = true
|
||||
return actualLength
|
||||
} catch (e: IllegalBlockSizeException) {
|
||||
throw IOException(e)
|
||||
} catch (e: BadPaddingException) {
|
||||
throw IOException(e)
|
||||
} catch (e: ShortBufferException) {
|
||||
throw IOException(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -161,32 +161,6 @@ class AttachmentCipherTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test(expected = InvalidMessageException::class)
|
||||
fun attachment_decryptFailOnNullDigest_nonIncremental() {
|
||||
attachment_decryptFailOnNullDigest(incremental = false)
|
||||
}
|
||||
|
||||
@Test(expected = InvalidMessageException::class)
|
||||
fun attachment_decryptFailOnNullDigest_incremental() {
|
||||
attachment_decryptFailOnNullDigest(incremental = true)
|
||||
}
|
||||
|
||||
private fun attachment_decryptFailOnNullDigest(incremental: Boolean) {
|
||||
var cipherFile: File? = null
|
||||
|
||||
try {
|
||||
val key = Util.getSecretBytes(64)
|
||||
val plaintextInput = Util.getSecretBytes(MEBIBYTE)
|
||||
val encryptResult = encryptData(plaintextInput, key, incremental)
|
||||
|
||||
cipherFile = writeToFile(encryptResult.ciphertext)
|
||||
|
||||
AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, null, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
|
||||
} finally {
|
||||
cipherFile?.delete()
|
||||
}
|
||||
}
|
||||
|
||||
@Test(expected = InvalidMessageException::class)
|
||||
fun attachment_decryptFailOnBadDigest_nonIncremental() {
|
||||
attachment_decryptFailOnBadDigest(incremental = false)
|
||||
@@ -293,7 +267,7 @@ class AttachmentCipherTest {
|
||||
val encryptResult = encryptData(plaintextInput, key, false)
|
||||
val cipherFile = writeToFile(encryptResult.ciphertext)
|
||||
|
||||
val inputStream = AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, plaintextInput.size.toLong())
|
||||
val inputStream = AttachmentCipherInputStream.createForArchivedMediaOuterLayer(keyMaterial, cipherFile, plaintextInput.size.toLong())
|
||||
val plaintextOutput = readInputStreamFully(inputStream)
|
||||
|
||||
assertThat(plaintextOutput).isEqualTo(plaintextInput)
|
||||
@@ -310,7 +284,7 @@ class AttachmentCipherTest {
|
||||
val encryptResult = encryptData(plaintextInput, key, false)
|
||||
val cipherFile = writeToFile(encryptResult.ciphertext)
|
||||
|
||||
val inputStream: InputStream = AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, plaintextInput.size.toLong())
|
||||
val inputStream: InputStream = AttachmentCipherInputStream.createForArchivedMediaOuterLayer(keyMaterial, cipherFile, plaintextInput.size.toLong())
|
||||
val plaintextOutput = readInputStreamFully(inputStream)
|
||||
|
||||
assertThat(plaintextOutput).isEqualTo(plaintextInput)
|
||||
@@ -332,7 +306,7 @@ class AttachmentCipherTest {
|
||||
val encryptResult = encryptData(plaintextInput, key, false)
|
||||
cipherFile = writeToFile(encryptResult.ciphertext)
|
||||
|
||||
AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, plaintextInput.size.toLong())
|
||||
AttachmentCipherInputStream.createForArchivedMediaOuterLayer(keyMaterial, cipherFile, plaintextInput.size.toLong())
|
||||
} catch (e: InvalidMessageException) {
|
||||
hitCorrectException = true
|
||||
} finally {
|
||||
@@ -368,7 +342,7 @@ class AttachmentCipherTest {
|
||||
val cipherFile = writeToFile(encryptedData)
|
||||
|
||||
val keyMaterial = createMediaKeyMaterial(key)
|
||||
val decryptedStream: InputStream = AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, length.toLong())
|
||||
val decryptedStream: InputStream = AttachmentCipherInputStream.createForArchivedMediaOuterLayer(keyMaterial, cipherFile, length.toLong())
|
||||
val plaintextOutput = readInputStreamFully(decryptedStream)
|
||||
|
||||
Assert.assertArrayEquals(plaintextInput, plaintextOutput)
|
||||
@@ -377,6 +351,60 @@ class AttachmentCipherTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun archiveInnerAndOuter_encryptDecrypt_nonIncremental() {
|
||||
archiveInnerAndOuter_encryptDecrypt(incremental = false, fileSize = MEBIBYTE)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun archiveInnerAndOuter_encryptDecrypt_incremental() {
|
||||
archiveInnerAndOuter_encryptDecrypt(incremental = true, fileSize = MEBIBYTE)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun archiveInnerAndOuter_encryptDecrypt_nonIncremental_manyFileSizes() {
|
||||
for (i in 0..99) {
|
||||
archiveInnerAndOuter_encryptDecrypt(incremental = false, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024))
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun archiveInnerAndOuter_encryptDecrypt_incremental_manyFileSizes() {
|
||||
// Designed to stress the various boundary conditions of reading the final mac
|
||||
for (i in 0..99) {
|
||||
archiveInnerAndOuter_encryptDecrypt(incremental = true, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024))
|
||||
}
|
||||
}
|
||||
|
||||
private fun archiveInnerAndOuter_encryptDecrypt(incremental: Boolean, fileSize: Int) {
|
||||
val innerKey = Util.getSecretBytes(64)
|
||||
val plaintextInput = Util.getSecretBytes(fileSize)
|
||||
|
||||
val innerEncryptResult = encryptData(plaintextInput, innerKey, incremental)
|
||||
val outerKey = Util.getSecretBytes(64)
|
||||
|
||||
val outerEncryptResult = encryptData(innerEncryptResult.ciphertext, outerKey, false)
|
||||
val cipherFile = writeToFile(outerEncryptResult.ciphertext)
|
||||
|
||||
val keyMaterial = createMediaKeyMaterial(outerKey)
|
||||
val decryptedStream: LimitedInputStream = AttachmentCipherInputStream.createForArchivedMediaOuterAndInnerLayers(
|
||||
archivedMediaKeyMaterial = keyMaterial,
|
||||
file = cipherFile,
|
||||
originalCipherTextLength = innerEncryptResult.ciphertext.size.toLong(),
|
||||
plaintextLength = plaintextInput.size.toLong(),
|
||||
combinedKeyMaterial = innerKey,
|
||||
digest = innerEncryptResult.digest,
|
||||
incrementalDigest = innerEncryptResult.incrementalDigest,
|
||||
incrementalMacChunkSize = innerEncryptResult.chunkSizeChoice
|
||||
)
|
||||
val plaintextOutput = decryptedStream.readFully(autoClose = false)
|
||||
|
||||
assertThat(plaintextOutput).isEqualTo(plaintextInput)
|
||||
assertThat(decryptedStream.leftoverStream().allMatch { it == 0.toByte() }).isTrue()
|
||||
|
||||
cipherFile.delete()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun archive_decryptFailOnBadMac() {
|
||||
var cipherFile: File? = null
|
||||
@@ -393,7 +421,7 @@ class AttachmentCipherTest {
|
||||
cipherFile = writeToFile(badMacCiphertext)
|
||||
|
||||
val keyMaterial = createMediaKeyMaterial(key)
|
||||
AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, plaintextInput.size.toLong())
|
||||
AttachmentCipherInputStream.createForArchivedMediaOuterLayer(keyMaterial, cipherFile, plaintextInput.size.toLong())
|
||||
Assert.fail()
|
||||
} catch (e: InvalidMessageException) {
|
||||
hitCorrectException = true
|
||||
|
||||
Reference in New Issue
Block a user