diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/messages/EnvelopeContentValidator.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/messages/EnvelopeContentValidator.kt index db240092af..d73a137700 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/messages/EnvelopeContentValidator.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/messages/EnvelopeContentValidator.kt @@ -1,5 +1,6 @@ package org.whispersystems.signalservice.api.messages +import okio.ByteString import org.signal.libsignal.protocol.message.DecryptionErrorMessage import org.signal.libsignal.protocol.message.SenderKeyDistributionMessage import org.signal.libsignal.zkgroup.InvalidInputException @@ -13,6 +14,7 @@ import org.whispersystems.signalservice.internal.push.DataMessage import org.whispersystems.signalservice.internal.push.EditMessage import org.whispersystems.signalservice.internal.push.Envelope import org.whispersystems.signalservice.internal.push.GroupContextV2 +import org.whispersystems.signalservice.internal.push.PniSignatureMessage import org.whispersystems.signalservice.internal.push.ReceiptMessage import org.whispersystems.signalservice.internal.push.StoryMessage import org.whispersystems.signalservice.internal.push.SyncMessage @@ -43,6 +45,10 @@ object EnvelopeContentValidator { validateSenderKeyDistributionMessage(content.senderKeyDistributionMessage.toByteArray())?.let { return it } } + if (content.pniSignatureMessage != null) { + validatePniSignatureMessage(content.pniSignatureMessage)?.let { return it } + } + // Reminder: envelope.destinationServiceId was already validated since we need that for decryption return when { @@ -255,6 +261,18 @@ object EnvelopeContentValidator { } } + private fun validatePniSignatureMessage(pniSignatureMessage: PniSignatureMessage): Result? { + if (pniSignatureMessage.pni.isNullOrInvalidPni()) { + return Result.Invalid("[PniSignatureMessage] Invalid PNI") + } + + if (pniSignatureMessage.signature == null) { + return Result.Invalid("[PniSignatureMessage] Signature is null") + } + + return null + } + private fun validateStoryMessage(storyMessage: StoryMessage): Result { if (storyMessage.group != null) { validateGroupContextV2(storyMessage.group, "[StoryMessage]")?.let { return it } @@ -328,6 +346,11 @@ object EnvelopeContentValidator { return parsed == null || parsed.isUnknown } + private fun ByteString?.isNullOrInvalidPni(): Boolean { + val parsed = ServiceId.PNI.parseOrNull(this?.toByteArray()) + return parsed == null || parsed.isUnknown + } + private fun Content?.meetsStoryFlagCriteria(): Boolean { return when { this == null -> false