From 6d2d3ae528af574206308652fe35d2acbec3fd9e Mon Sep 17 00:00:00 2001 From: Greyson Parrelli Date: Wed, 9 Aug 2023 14:54:06 -0400 Subject: [PATCH] Improve ServiceId parsing functions. --- .../securesms/database/RecipientTable.kt | 2 +- .../securesms/migrations/PniMigrationJob.java | 2 +- .../registration/RegistrationRepository.java | 2 +- .../signalservice/api/push/ServiceId.kt | 52 +++++--- .../api/storage/SignalContactRecord.java | 2 +- .../push/PushTransportDetailsTest.java | 7 +- .../signalservice/api/push/ServiceIdTests.kt | 115 ++++++++++++++++++ 7 files changed, 162 insertions(+), 20 deletions(-) rename libsignal/service/src/test/java/org/whispersystems/signalservice/{ => api}/push/PushTransportDetailsTest.java (83%) create mode 100644 libsignal/service/src/test/java/org/whispersystems/signalservice/api/push/ServiceIdTests.kt diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/RecipientTable.kt b/app/src/main/java/org/thoughtcrime/securesms/database/RecipientTable.kt index 5dff8ba3c8..5bb3ab287d 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/RecipientTable.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/database/RecipientTable.kt @@ -4154,7 +4154,7 @@ open class RecipientTable(context: Context, databaseHelper: SignalDatabase) : Da return RecipientRecord( id = recipientId, aci = ACI.parseOrNull(cursor.requireString(ACI_COLUMN)), - pni = PNI.parseOrNull(cursor.requireString(PNI_COLUMN)), + pni = PNI.parsePrefixedOrNull(cursor.requireString(PNI_COLUMN)), username = cursor.requireString(USERNAME), e164 = cursor.requireString(E164), email = cursor.requireString(EMAIL), diff --git a/app/src/main/java/org/thoughtcrime/securesms/migrations/PniMigrationJob.java b/app/src/main/java/org/thoughtcrime/securesms/migrations/PniMigrationJob.java index 45eb12cf5d..6035f6cdc8 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/migrations/PniMigrationJob.java +++ b/app/src/main/java/org/thoughtcrime/securesms/migrations/PniMigrationJob.java @@ -47,7 +47,7 @@ public class PniMigrationJob extends MigrationJob { return; } - PNI pni = PNI.parseUnPrefixedOrNull(ApplicationDependencies.getSignalServiceAccountManager().getWhoAmI().getPni()); + PNI pni = PNI.parseOrNull(ApplicationDependencies.getSignalServiceAccountManager().getWhoAmI().getPni()); if (pni == null) { throw new IOException("Invalid PNI!"); diff --git a/app/src/main/java/org/thoughtcrime/securesms/registration/RegistrationRepository.java b/app/src/main/java/org/thoughtcrime/securesms/registration/RegistrationRepository.java index 218242cf70..3f68d4e718 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/registration/RegistrationRepository.java +++ b/app/src/main/java/org/thoughtcrime/securesms/registration/RegistrationRepository.java @@ -129,7 +129,7 @@ public final class RegistrationRepository { Preconditions.checkNotNull(response.getPniPreKeyCollection(), "Missing PNI prekey collection!"); ACI aci = ACI.parseOrThrow(response.getVerifyAccountResponse().getUuid()); - PNI pni = PNI.parseUnPrefixedOrThrow(response.getVerifyAccountResponse().getPni()); + PNI pni = PNI.parseOrThrow(response.getVerifyAccountResponse().getPni()); boolean hasPin = response.getVerifyAccountResponse().isStorageCapable(); SignalStore.account().setAci(aci); diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/push/ServiceId.kt b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/push/ServiceId.kt index 6adaed5b9e..4fbce5d5ae 100644 --- a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/push/ServiceId.kt +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/push/ServiceId.kt @@ -54,7 +54,11 @@ sealed class ServiceId(val libSignalServiceId: LibSignalServiceId) { } return try { - fromLibSignal(LibSignalServiceId.parseFromBinary(raw)) + if (raw.size == 17) { + fromLibSignal(LibSignalServiceId.parseFromFixedWidthBinary(raw)) + } else { + fromLibSignal(LibSignalServiceId.parseFromBinary(raw)) + } } catch (e: IllegalArgumentException) { null } catch (e: InvalidServiceIdException) { @@ -152,39 +156,57 @@ sealed class ServiceId(val libSignalServiceId: LibSignalServiceId) { @JvmStatic fun from(uuid: UUID): PNI = PNI(LibSignalPni(uuid)) + /** Parses a string as a PNI, regardless if the `PNI:` prefix is present or not. Only use this if you are certain that what you're reading is a PNI. */ @JvmStatic - fun parseOrNull(raw: String?): PNI? = ServiceId.parseOrNull(raw).let { if (it is PNI) it else null } - - /** Parses a plain UUID (without the `PNI:` prefix) as a PNI. Be certain that whatever you pass to this is for sure a PNI! */ - @JvmStatic - fun parseUnPrefixedOrNull(raw: String?): PNI? { - val uuid = UuidUtil.parseOrNull(raw) - return if (uuid != null) { - PNI(LibSignalPni(uuid)) - } else { + fun parseOrNull(raw: String?): PNI? { + return if (raw == null) { null + } else if (raw.startsWith("PNI:")) { + return parsePrefixedOrNull(raw) + } else { + val uuid = UuidUtil.parseOrNull(raw) + if (uuid != null) { + PNI(LibSignalPni(uuid)) + } else { + null + } } } + /** Parse a byte array as a PNI, regardless if it has the type prefix byte present or not. Only use this if you are certain what you're reading is a PNI. */ @JvmStatic - fun parseOrNull(raw: ByteArray?): PNI? = ServiceId.parseOrNull(raw).let { if (it is PNI) it else null } + fun parseOrNull(raw: ByteArray?): PNI? { + return if (raw == null) { + null + } else if (raw.size == 17) { + ServiceId.parseOrNull(raw).let { if (it is PNI) it else null } + } else { + val uuid = UuidUtil.parseOrNull(raw) + if (uuid != null) { + PNI(LibSignalPni(uuid)) + } else { + null + } + } + } + /** Parses a string as a PNI, regardless if the `PNI:` prefix is present or not. Only use this if you are certain that what you're reading is a PNI. */ @JvmStatic @Throws(IllegalArgumentException::class) fun parseOrThrow(raw: String?): PNI = parseOrNull(raw) ?: throw IllegalArgumentException("Invalid PNI!") + /** Parse a byte array as a PNI, regardless if it has the type prefix byte present or not. Only use this if you are certain what you're reading is a PNI. */ @JvmStatic @Throws(IllegalArgumentException::class) fun parseOrThrow(raw: ByteArray?): PNI = parseOrNull(raw) ?: throw IllegalArgumentException("Invalid PNI!") + /** Parse a byte string as a PNI, regardless if it has the type prefix byte present or not. Only use this if you are certain what you're reading is a PNI. */ @JvmStatic @Throws(IllegalArgumentException::class) fun parseOrThrow(bytes: ByteString): PNI = parseOrThrow(bytes.toByteArray()) - /** Parses a plain UUID (without the `PNI:` prefix) as a PNI. Be certain that whatever you pass to this is for sure a PNI! */ - @JvmStatic - @Throws(IllegalArgumentException::class) - fun parseUnPrefixedOrThrow(raw: String?): PNI = parseUnPrefixedOrNull(raw) ?: throw IllegalArgumentException("Invalid PNI!") + /** Parses a string as a PNI, expecting that the value has a `PNI:` prefix. If it does not have the prefix (or is otherwise invalid), this will return null. */ + fun parsePrefixedOrNull(raw: String?): PNI? = ServiceId.parseOrNull(raw).let { if (it is PNI) it else null } } override fun toString(): String = super.toString() diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/SignalContactRecord.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/SignalContactRecord.java index c886d25168..32b643ffb8 100644 --- a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/SignalContactRecord.java +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/SignalContactRecord.java @@ -43,7 +43,7 @@ public final class SignalContactRecord implements SignalRecord { this.proto = proto; this.hasUnknownFields = ProtoUtil.hasUnknownFields(proto); this.aci = ACI.parseOrUnknown(proto.getAci()); - this.pni = OptionalUtil.absentIfEmpty(proto.getPni()).map(PNI::parseUnPrefixedOrNull); + this.pni = OptionalUtil.absentIfEmpty(proto.getPni()).map(PNI::parseOrNull); this.e164 = OptionalUtil.absentIfEmpty(proto.getE164()); this.profileGivenName = OptionalUtil.absentIfEmpty(proto.getGivenName()); this.profileFamilyName = OptionalUtil.absentIfEmpty(proto.getFamilyName()); diff --git a/libsignal/service/src/test/java/org/whispersystems/signalservice/push/PushTransportDetailsTest.java b/libsignal/service/src/test/java/org/whispersystems/signalservice/api/push/PushTransportDetailsTest.java similarity index 83% rename from libsignal/service/src/test/java/org/whispersystems/signalservice/push/PushTransportDetailsTest.java rename to libsignal/service/src/test/java/org/whispersystems/signalservice/api/push/PushTransportDetailsTest.java index 26cd13e26f..cd99985da2 100644 --- a/libsignal/service/src/test/java/org/whispersystems/signalservice/push/PushTransportDetailsTest.java +++ b/libsignal/service/src/test/java/org/whispersystems/signalservice/api/push/PushTransportDetailsTest.java @@ -1,4 +1,9 @@ -package org.whispersystems.signalservice.push; +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.signalservice.api.push; import junit.framework.TestCase; diff --git a/libsignal/service/src/test/java/org/whispersystems/signalservice/api/push/ServiceIdTests.kt b/libsignal/service/src/test/java/org/whispersystems/signalservice/api/push/ServiceIdTests.kt new file mode 100644 index 0000000000..a79a545e57 --- /dev/null +++ b/libsignal/service/src/test/java/org/whispersystems/signalservice/api/push/ServiceIdTests.kt @@ -0,0 +1,115 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.signalservice.api.push + +import junit.framework.TestCase.assertEquals +import junit.framework.TestCase.assertNull +import org.junit.Test +import org.whispersystems.signalservice.api.push.ServiceId.ACI +import org.whispersystems.signalservice.api.push.ServiceId.PNI +import org.whispersystems.signalservice.api.util.UuidUtil +import java.util.UUID +import org.signal.libsignal.protocol.ServiceId.Aci as LibSignalAci +import org.signal.libsignal.protocol.ServiceId.Pni as LibSignalPni + +class ServiceIdTests { + + @Test + fun `ServiceId parseOrNull String`() { + val uuidString = UUID.randomUUID().toString() + + assertNull(ServiceId.parseOrNull(null as String?)) + assertNull(ServiceId.parseOrNull("")) + assertNull(ServiceId.parseOrNull("asdf")) + + assertEquals(ACI.from(UUID.fromString(uuidString)), ServiceId.parseOrNull(uuidString)) + assertEquals(PNI.from(UUID.fromString(uuidString)), ServiceId.parseOrNull("PNI:$uuidString")) + } + + @Test + fun `ServiceId parseOrNull ByteArray`() { + val uuid = UUID.randomUUID() + val uuidString = uuid.toString() + val uuidBytes = UuidUtil.toByteArray(uuid) + + assertNull(ServiceId.parseOrNull(null as ByteArray?)) + assertNull(ServiceId.parseOrNull(ByteArray(0))) + assertNull(ServiceId.parseOrNull(byteArrayOf(1, 2, 3))) + + assertEquals(ACI.from(UUID.fromString(uuidString)), ServiceId.parseOrNull(uuidBytes)) + assertEquals(PNI.from(UUID.fromString(uuidString)), ServiceId.parseOrNull(LibSignalPni(uuid).toServiceIdBinary())) + assertEquals(ACI.from(UUID.fromString(uuidString)), ServiceId.parseOrNull(LibSignalAci(uuid).toServiceIdFixedWidthBinary())) + assertEquals(PNI.from(UUID.fromString(uuidString)), ServiceId.parseOrNull(LibSignalPni(uuid).toServiceIdFixedWidthBinary())) + } + + @Test + fun `ACI parseOrNull String`() { + val uuid = UUID.randomUUID() + val uuidString = uuid.toString() + + assertNull(ACI.parseOrNull(null as String?)) + assertNull(ACI.parseOrNull("")) + assertNull(ACI.parseOrNull("asdf")) + assertNull(ACI.parseOrNull(LibSignalPni(uuid).toServiceIdString())) + + assertEquals(ACI.from(UUID.fromString(uuidString)), ACI.parseOrNull(uuidString)) + } + + @Test + fun `ACI parseOrNull ByteArray`() { + val uuid = UUID.randomUUID() + val uuidString = uuid.toString() + val uuidBytes = UuidUtil.toByteArray(uuid) + + assertNull(ACI.parseOrNull(null as ByteArray?)) + assertNull(ACI.parseOrNull(ByteArray(0))) + assertNull(ACI.parseOrNull(byteArrayOf(1, 2, 3))) + assertNull(ACI.parseOrNull(LibSignalPni(uuid).toServiceIdBinary())) + + assertEquals(ACI.from(UUID.fromString(uuidString)), ACI.parseOrNull(uuidBytes)) + assertEquals(ACI.from(UUID.fromString(uuidString)), ACI.parseOrNull(LibSignalAci(uuid).toServiceIdBinary())) + assertEquals(ACI.from(UUID.fromString(uuidString)), ACI.parseOrNull(LibSignalAci(uuid).toServiceIdFixedWidthBinary())) + } + + @Test + fun `PNI parseOrNull String`() { + val uuidString = UUID.randomUUID().toString() + + assertNull(PNI.parseOrNull(null as String?)) + assertNull(PNI.parseOrNull("")) + assertNull(PNI.parseOrNull("asdf")) + + assertEquals(PNI.from(UUID.fromString(uuidString)), PNI.parseOrNull(uuidString)) + assertEquals(PNI.from(UUID.fromString(uuidString)), PNI.parseOrNull("PNI:$uuidString")) + } + + @Test + fun `PNI parseOrNull ByteArray`() { + val uuid = UUID.randomUUID() + val uuidString = uuid.toString() + val uuidBytes = UuidUtil.toByteArray(uuid) + + assertNull(PNI.parseOrNull(null as ByteArray?)) + assertNull(PNI.parseOrNull(ByteArray(0))) + assertNull(PNI.parseOrNull(byteArrayOf(1, 2, 3))) + assertNull(PNI.parseOrNull(LibSignalAci(uuid).toServiceIdFixedWidthBinary())) + + assertEquals(PNI.from(UUID.fromString(uuidString)), PNI.parseOrNull(uuidBytes)) + assertEquals(PNI.from(UUID.fromString(uuidString)), PNI.parseOrNull(LibSignalPni(uuid).toServiceIdBinary())) + } + + @Test + fun `PNI parsePrefixedOrNull`() { + val uuidString = UUID.randomUUID().toString() + + assertNull(PNI.parsePrefixedOrNull(null)) + assertNull(PNI.parsePrefixedOrNull("")) + assertNull(PNI.parsePrefixedOrNull("asdf")) + assertNull(PNI.parsePrefixedOrNull(uuidString)) + + assertEquals(PNI.from(UUID.fromString(uuidString)), PNI.parsePrefixedOrNull("PNI:$uuidString")) + } +}