Improve ServiceId parsing functions.

This commit is contained in:
Greyson Parrelli
2023-08-09 14:54:06 -04:00
committed by Alex Hart
parent 784f94ecdb
commit 6d2d3ae528
7 changed files with 162 additions and 20 deletions

View File

@@ -4154,7 +4154,7 @@ open class RecipientTable(context: Context, databaseHelper: SignalDatabase) : Da
return RecipientRecord( return RecipientRecord(
id = recipientId, id = recipientId,
aci = ACI.parseOrNull(cursor.requireString(ACI_COLUMN)), aci = ACI.parseOrNull(cursor.requireString(ACI_COLUMN)),
pni = PNI.parseOrNull(cursor.requireString(PNI_COLUMN)), pni = PNI.parsePrefixedOrNull(cursor.requireString(PNI_COLUMN)),
username = cursor.requireString(USERNAME), username = cursor.requireString(USERNAME),
e164 = cursor.requireString(E164), e164 = cursor.requireString(E164),
email = cursor.requireString(EMAIL), email = cursor.requireString(EMAIL),

View File

@@ -47,7 +47,7 @@ public class PniMigrationJob extends MigrationJob {
return; return;
} }
PNI pni = PNI.parseUnPrefixedOrNull(ApplicationDependencies.getSignalServiceAccountManager().getWhoAmI().getPni()); PNI pni = PNI.parseOrNull(ApplicationDependencies.getSignalServiceAccountManager().getWhoAmI().getPni());
if (pni == null) { if (pni == null) {
throw new IOException("Invalid PNI!"); throw new IOException("Invalid PNI!");

View File

@@ -129,7 +129,7 @@ public final class RegistrationRepository {
Preconditions.checkNotNull(response.getPniPreKeyCollection(), "Missing PNI prekey collection!"); Preconditions.checkNotNull(response.getPniPreKeyCollection(), "Missing PNI prekey collection!");
ACI aci = ACI.parseOrThrow(response.getVerifyAccountResponse().getUuid()); ACI aci = ACI.parseOrThrow(response.getVerifyAccountResponse().getUuid());
PNI pni = PNI.parseUnPrefixedOrThrow(response.getVerifyAccountResponse().getPni()); PNI pni = PNI.parseOrThrow(response.getVerifyAccountResponse().getPni());
boolean hasPin = response.getVerifyAccountResponse().isStorageCapable(); boolean hasPin = response.getVerifyAccountResponse().isStorageCapable();
SignalStore.account().setAci(aci); SignalStore.account().setAci(aci);

View File

@@ -54,7 +54,11 @@ sealed class ServiceId(val libSignalServiceId: LibSignalServiceId) {
} }
return try { return try {
fromLibSignal(LibSignalServiceId.parseFromBinary(raw)) if (raw.size == 17) {
fromLibSignal(LibSignalServiceId.parseFromFixedWidthBinary(raw))
} else {
fromLibSignal(LibSignalServiceId.parseFromBinary(raw))
}
} catch (e: IllegalArgumentException) { } catch (e: IllegalArgumentException) {
null null
} catch (e: InvalidServiceIdException) { } catch (e: InvalidServiceIdException) {
@@ -152,39 +156,57 @@ sealed class ServiceId(val libSignalServiceId: LibSignalServiceId) {
@JvmStatic @JvmStatic
fun from(uuid: UUID): PNI = PNI(LibSignalPni(uuid)) fun from(uuid: UUID): PNI = PNI(LibSignalPni(uuid))
/** Parses a string as a PNI, regardless if the `PNI:` prefix is present or not. Only use this if you are certain that what you're reading is a PNI. */
@JvmStatic @JvmStatic
fun parseOrNull(raw: String?): PNI? = ServiceId.parseOrNull(raw).let { if (it is PNI) it else null } fun parseOrNull(raw: String?): PNI? {
return if (raw == null) {
/** Parses a plain UUID (without the `PNI:` prefix) as a PNI. Be certain that whatever you pass to this is for sure a PNI! */
@JvmStatic
fun parseUnPrefixedOrNull(raw: String?): PNI? {
val uuid = UuidUtil.parseOrNull(raw)
return if (uuid != null) {
PNI(LibSignalPni(uuid))
} else {
null null
} else if (raw.startsWith("PNI:")) {
return parsePrefixedOrNull(raw)
} else {
val uuid = UuidUtil.parseOrNull(raw)
if (uuid != null) {
PNI(LibSignalPni(uuid))
} else {
null
}
} }
} }
/** Parse a byte array as a PNI, regardless if it has the type prefix byte present or not. Only use this if you are certain what you're reading is a PNI. */
@JvmStatic @JvmStatic
fun parseOrNull(raw: ByteArray?): PNI? = ServiceId.parseOrNull(raw).let { if (it is PNI) it else null } fun parseOrNull(raw: ByteArray?): PNI? {
return if (raw == null) {
null
} else if (raw.size == 17) {
ServiceId.parseOrNull(raw).let { if (it is PNI) it else null }
} else {
val uuid = UuidUtil.parseOrNull(raw)
if (uuid != null) {
PNI(LibSignalPni(uuid))
} else {
null
}
}
}
/** Parses a string as a PNI, regardless if the `PNI:` prefix is present or not. Only use this if you are certain that what you're reading is a PNI. */
@JvmStatic @JvmStatic
@Throws(IllegalArgumentException::class) @Throws(IllegalArgumentException::class)
fun parseOrThrow(raw: String?): PNI = parseOrNull(raw) ?: throw IllegalArgumentException("Invalid PNI!") fun parseOrThrow(raw: String?): PNI = parseOrNull(raw) ?: throw IllegalArgumentException("Invalid PNI!")
/** Parse a byte array as a PNI, regardless if it has the type prefix byte present or not. Only use this if you are certain what you're reading is a PNI. */
@JvmStatic @JvmStatic
@Throws(IllegalArgumentException::class) @Throws(IllegalArgumentException::class)
fun parseOrThrow(raw: ByteArray?): PNI = parseOrNull(raw) ?: throw IllegalArgumentException("Invalid PNI!") fun parseOrThrow(raw: ByteArray?): PNI = parseOrNull(raw) ?: throw IllegalArgumentException("Invalid PNI!")
/** Parse a byte string as a PNI, regardless if it has the type prefix byte present or not. Only use this if you are certain what you're reading is a PNI. */
@JvmStatic @JvmStatic
@Throws(IllegalArgumentException::class) @Throws(IllegalArgumentException::class)
fun parseOrThrow(bytes: ByteString): PNI = parseOrThrow(bytes.toByteArray()) fun parseOrThrow(bytes: ByteString): PNI = parseOrThrow(bytes.toByteArray())
/** Parses a plain UUID (without the `PNI:` prefix) as a PNI. Be certain that whatever you pass to this is for sure a PNI! */ /** Parses a string as a PNI, expecting that the value has a `PNI:` prefix. If it does not have the prefix (or is otherwise invalid), this will return null. */
@JvmStatic fun parsePrefixedOrNull(raw: String?): PNI? = ServiceId.parseOrNull(raw).let { if (it is PNI) it else null }
@Throws(IllegalArgumentException::class)
fun parseUnPrefixedOrThrow(raw: String?): PNI = parseUnPrefixedOrNull(raw) ?: throw IllegalArgumentException("Invalid PNI!")
} }
override fun toString(): String = super.toString() override fun toString(): String = super.toString()

View File

@@ -43,7 +43,7 @@ public final class SignalContactRecord implements SignalRecord {
this.proto = proto; this.proto = proto;
this.hasUnknownFields = ProtoUtil.hasUnknownFields(proto); this.hasUnknownFields = ProtoUtil.hasUnknownFields(proto);
this.aci = ACI.parseOrUnknown(proto.getAci()); this.aci = ACI.parseOrUnknown(proto.getAci());
this.pni = OptionalUtil.absentIfEmpty(proto.getPni()).map(PNI::parseUnPrefixedOrNull); this.pni = OptionalUtil.absentIfEmpty(proto.getPni()).map(PNI::parseOrNull);
this.e164 = OptionalUtil.absentIfEmpty(proto.getE164()); this.e164 = OptionalUtil.absentIfEmpty(proto.getE164());
this.profileGivenName = OptionalUtil.absentIfEmpty(proto.getGivenName()); this.profileGivenName = OptionalUtil.absentIfEmpty(proto.getGivenName());
this.profileFamilyName = OptionalUtil.absentIfEmpty(proto.getFamilyName()); this.profileFamilyName = OptionalUtil.absentIfEmpty(proto.getFamilyName());

View File

@@ -1,4 +1,9 @@
package org.whispersystems.signalservice.push; /*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.signalservice.api.push;
import junit.framework.TestCase; import junit.framework.TestCase;

View File

@@ -0,0 +1,115 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.signalservice.api.push
import junit.framework.TestCase.assertEquals
import junit.framework.TestCase.assertNull
import org.junit.Test
import org.whispersystems.signalservice.api.push.ServiceId.ACI
import org.whispersystems.signalservice.api.push.ServiceId.PNI
import org.whispersystems.signalservice.api.util.UuidUtil
import java.util.UUID
import org.signal.libsignal.protocol.ServiceId.Aci as LibSignalAci
import org.signal.libsignal.protocol.ServiceId.Pni as LibSignalPni
class ServiceIdTests {
@Test
fun `ServiceId parseOrNull String`() {
val uuidString = UUID.randomUUID().toString()
assertNull(ServiceId.parseOrNull(null as String?))
assertNull(ServiceId.parseOrNull(""))
assertNull(ServiceId.parseOrNull("asdf"))
assertEquals(ACI.from(UUID.fromString(uuidString)), ServiceId.parseOrNull(uuidString))
assertEquals(PNI.from(UUID.fromString(uuidString)), ServiceId.parseOrNull("PNI:$uuidString"))
}
@Test
fun `ServiceId parseOrNull ByteArray`() {
val uuid = UUID.randomUUID()
val uuidString = uuid.toString()
val uuidBytes = UuidUtil.toByteArray(uuid)
assertNull(ServiceId.parseOrNull(null as ByteArray?))
assertNull(ServiceId.parseOrNull(ByteArray(0)))
assertNull(ServiceId.parseOrNull(byteArrayOf(1, 2, 3)))
assertEquals(ACI.from(UUID.fromString(uuidString)), ServiceId.parseOrNull(uuidBytes))
assertEquals(PNI.from(UUID.fromString(uuidString)), ServiceId.parseOrNull(LibSignalPni(uuid).toServiceIdBinary()))
assertEquals(ACI.from(UUID.fromString(uuidString)), ServiceId.parseOrNull(LibSignalAci(uuid).toServiceIdFixedWidthBinary()))
assertEquals(PNI.from(UUID.fromString(uuidString)), ServiceId.parseOrNull(LibSignalPni(uuid).toServiceIdFixedWidthBinary()))
}
@Test
fun `ACI parseOrNull String`() {
val uuid = UUID.randomUUID()
val uuidString = uuid.toString()
assertNull(ACI.parseOrNull(null as String?))
assertNull(ACI.parseOrNull(""))
assertNull(ACI.parseOrNull("asdf"))
assertNull(ACI.parseOrNull(LibSignalPni(uuid).toServiceIdString()))
assertEquals(ACI.from(UUID.fromString(uuidString)), ACI.parseOrNull(uuidString))
}
@Test
fun `ACI parseOrNull ByteArray`() {
val uuid = UUID.randomUUID()
val uuidString = uuid.toString()
val uuidBytes = UuidUtil.toByteArray(uuid)
assertNull(ACI.parseOrNull(null as ByteArray?))
assertNull(ACI.parseOrNull(ByteArray(0)))
assertNull(ACI.parseOrNull(byteArrayOf(1, 2, 3)))
assertNull(ACI.parseOrNull(LibSignalPni(uuid).toServiceIdBinary()))
assertEquals(ACI.from(UUID.fromString(uuidString)), ACI.parseOrNull(uuidBytes))
assertEquals(ACI.from(UUID.fromString(uuidString)), ACI.parseOrNull(LibSignalAci(uuid).toServiceIdBinary()))
assertEquals(ACI.from(UUID.fromString(uuidString)), ACI.parseOrNull(LibSignalAci(uuid).toServiceIdFixedWidthBinary()))
}
@Test
fun `PNI parseOrNull String`() {
val uuidString = UUID.randomUUID().toString()
assertNull(PNI.parseOrNull(null as String?))
assertNull(PNI.parseOrNull(""))
assertNull(PNI.parseOrNull("asdf"))
assertEquals(PNI.from(UUID.fromString(uuidString)), PNI.parseOrNull(uuidString))
assertEquals(PNI.from(UUID.fromString(uuidString)), PNI.parseOrNull("PNI:$uuidString"))
}
@Test
fun `PNI parseOrNull ByteArray`() {
val uuid = UUID.randomUUID()
val uuidString = uuid.toString()
val uuidBytes = UuidUtil.toByteArray(uuid)
assertNull(PNI.parseOrNull(null as ByteArray?))
assertNull(PNI.parseOrNull(ByteArray(0)))
assertNull(PNI.parseOrNull(byteArrayOf(1, 2, 3)))
assertNull(PNI.parseOrNull(LibSignalAci(uuid).toServiceIdFixedWidthBinary()))
assertEquals(PNI.from(UUID.fromString(uuidString)), PNI.parseOrNull(uuidBytes))
assertEquals(PNI.from(UUID.fromString(uuidString)), PNI.parseOrNull(LibSignalPni(uuid).toServiceIdBinary()))
}
@Test
fun `PNI parsePrefixedOrNull`() {
val uuidString = UUID.randomUUID().toString()
assertNull(PNI.parsePrefixedOrNull(null))
assertNull(PNI.parsePrefixedOrNull(""))
assertNull(PNI.parsePrefixedOrNull("asdf"))
assertNull(PNI.parsePrefixedOrNull(uuidString))
assertEquals(PNI.from(UUID.fromString(uuidString)), PNI.parsePrefixedOrNull("PNI:$uuidString"))
}
}