From bea204ab822ce442e654b09f23e0b352cef978bd Mon Sep 17 00:00:00 2001 From: Alex Hart Date: Thu, 18 Sep 2025 10:35:37 -0300 Subject: [PATCH] Convert GroupId to Kotlin. --- .../securesms/database/RecipientTable.kt | 2 +- .../securesms/groups/GroupId.java | 329 ----------------- .../thoughtcrime/securesms/groups/GroupId.kt | 330 ++++++++++++++++++ .../securesms/recipients/Recipient.kt | 8 +- .../colors/AvatarColorHashTest.kt | 2 +- 5 files changed, 336 insertions(+), 335 deletions(-) delete mode 100644 app/src/main/java/org/thoughtcrime/securesms/groups/GroupId.java create mode 100644 app/src/main/java/org/thoughtcrime/securesms/groups/GroupId.kt diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/RecipientTable.kt b/app/src/main/java/org/thoughtcrime/securesms/database/RecipientTable.kt index 2817a9beaf..d15656ad55 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/RecipientTable.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/database/RecipientTable.kt @@ -3745,7 +3745,7 @@ open class RecipientTable(context: Context, databaseHelper: SignalDatabase) : Da } if (blockedGroupIds.isNotEmpty()) { - val groupIds: List = blockedGroupIds.mapNotNull { raw -> + val groupIds: List = blockedGroupIds.filterNotNull().mapNotNull { raw -> try { GroupId.v1(raw) } catch (e: BadGroupIdException) { diff --git a/app/src/main/java/org/thoughtcrime/securesms/groups/GroupId.java b/app/src/main/java/org/thoughtcrime/securesms/groups/GroupId.java deleted file mode 100644 index 90a72dfd0f..0000000000 --- a/app/src/main/java/org/thoughtcrime/securesms/groups/GroupId.java +++ /dev/null @@ -1,329 +0,0 @@ -package org.thoughtcrime.securesms.groups; - -import androidx.annotation.NonNull; -import androidx.annotation.Nullable; - -import org.signal.core.util.DatabaseId; -import org.signal.core.util.Hex; -import org.signal.libsignal.protocol.kdf.HKDF; -import org.signal.libsignal.zkgroup.InvalidInputException; -import org.signal.libsignal.zkgroup.groups.GroupIdentifier; -import org.signal.libsignal.zkgroup.groups.GroupMasterKey; -import org.signal.libsignal.zkgroup.groups.GroupSecretParams; -import org.thoughtcrime.securesms.util.LRUCache; -import org.thoughtcrime.securesms.util.Util; - -import java.io.IOException; -import java.security.SecureRandom; - -import okio.ByteString; - -public abstract class GroupId implements DatabaseId { - - private static final String ENCODED_SIGNAL_GROUP_V1_PREFIX = "__textsecure_group__!"; - private static final String ENCODED_SIGNAL_GROUP_V2_PREFIX = "__signal_group__v2__!"; - private static final String ENCODED_MMS_GROUP_PREFIX = "__signal_mms_group__!"; - private static final int MMS_BYTE_LENGTH = 16; - private static final int V1_MMS_BYTE_LENGTH = 16; - private static final int V1_BYTE_LENGTH = 16; - private static final int V2_BYTE_LENGTH = GroupIdentifier.SIZE; - - private final String encodedId; - - private static final LRUCache groupIdentifierCache = new LRUCache<>(1000); - - private GroupId(@NonNull String prefix, @NonNull byte[] bytes) { - this.encodedId = prefix + Hex.toStringCondensed(bytes); - } - - public static @NonNull GroupId.Mms mms(byte[] mmsGroupIdBytes) { - return new GroupId.Mms(mmsGroupIdBytes); - } - - public static @NonNull GroupId.V1 v1orThrow(byte[] gv1GroupIdBytes) { - try { - return v1(gv1GroupIdBytes); - } catch (BadGroupIdException e) { - throw new AssertionError(e); - } - } - - public static @NonNull GroupId.V1 v1(byte[] gv1GroupIdBytes) throws BadGroupIdException { - if (gv1GroupIdBytes.length != V1_BYTE_LENGTH) { - throw new BadGroupIdException(); - } - return new GroupId.V1(gv1GroupIdBytes); - } - - public static GroupId.V1 createV1(@NonNull SecureRandom secureRandom) { - return v1orThrow(Util.getSecretBytes(secureRandom, V1_MMS_BYTE_LENGTH)); - } - - public static GroupId.Mms createMms(@NonNull SecureRandom secureRandom) { - return mms(Util.getSecretBytes(secureRandom, MMS_BYTE_LENGTH)); - } - - /** - * Private because it's too easy to pass the {@link GroupMasterKey} bytes directly to this as they - * are the same length as the {@link GroupIdentifier}. - */ - private static GroupId.V2 v2(@NonNull byte[] bytes) throws BadGroupIdException { - if (bytes.length != V2_BYTE_LENGTH) { - throw new BadGroupIdException(); - } - return new GroupId.V2(bytes); - } - - public static GroupId.V2 v2(@NonNull GroupIdentifier groupIdentifier) { - try { - return v2(groupIdentifier.serialize()); - } catch (BadGroupIdException e) { - throw new AssertionError(e); - } - } - - public static GroupId.V2 v2(@NonNull GroupMasterKey masterKey) { - return v2(getIdentifierForMasterKey(masterKey)); - } - - public static GroupIdentifier getIdentifierForMasterKey(@NonNull GroupMasterKey masterKey) { - GroupIdentifier cachedIdentifier; - synchronized (groupIdentifierCache) { - cachedIdentifier = groupIdentifierCache.get(masterKey); - } - if (cachedIdentifier == null) { - cachedIdentifier = GroupSecretParams.deriveFromMasterKey(masterKey) - .getPublicParams() - .getGroupIdentifier(); - synchronized (groupIdentifierCache) { - groupIdentifierCache.put(masterKey, cachedIdentifier); - } - } - return cachedIdentifier; - } - - public static GroupId.Push push(ByteString bytes) throws BadGroupIdException { - return push(bytes.toByteArray()); - } - - public static GroupId.Push push(byte[] bytes) throws BadGroupIdException { - return bytes.length == V2_BYTE_LENGTH ? v2(bytes) : v1(bytes); - } - - public static GroupId.Push pushOrThrow(byte[] bytes) { - try { - return push(bytes); - } catch (BadGroupIdException e) { - throw new AssertionError(e); - } - } - - public static GroupId.Push pushOrNull(byte[] bytes) { - try { - return GroupId.push(bytes); - } catch (BadGroupIdException e) { - return null; - } - } - - public static @NonNull GroupId parseOrThrow(@NonNull String encodedGroupId) { - try { - return parse(encodedGroupId); - } catch (BadGroupIdException e) { - throw new AssertionError(e); - } - } - - public static @NonNull GroupId parse(@NonNull String encodedGroupId) throws BadGroupIdException { - try { - if (!isEncodedGroup(encodedGroupId)) { - throw new BadGroupIdException("Invalid encoding"); - } - - byte[] bytes = extractDecodedId(encodedGroupId); - - if (encodedGroupId.startsWith(ENCODED_SIGNAL_GROUP_V2_PREFIX)) return v2(bytes); - else if (encodedGroupId.startsWith(ENCODED_SIGNAL_GROUP_V1_PREFIX)) return v1(bytes); - else if (encodedGroupId.startsWith(ENCODED_MMS_GROUP_PREFIX)) return mms(bytes); - - throw new BadGroupIdException(); - } catch (IOException e) { - throw new BadGroupIdException(e); - } - } - - public static @Nullable GroupId parseNullable(@Nullable String encodedGroupId) throws BadGroupIdException { - if (encodedGroupId == null) { - return null; - } - - return parse(encodedGroupId); - } - - public static @Nullable GroupId parseNullableOrThrow(@Nullable String encodedGroupId) { - if (encodedGroupId == null) { - return null; - } - - return parseOrThrow(encodedGroupId); - } - - public static boolean isEncodedGroup(@NonNull String groupId) { - return groupId.startsWith(ENCODED_SIGNAL_GROUP_V2_PREFIX) || - groupId.startsWith(ENCODED_SIGNAL_GROUP_V1_PREFIX) || - groupId.startsWith(ENCODED_MMS_GROUP_PREFIX); - } - - private static byte[] extractDecodedId(@NonNull String encodedGroupId) throws IOException { - return Hex.fromStringCondensed(encodedGroupId.split("!", 2)[1]); - } - - public byte[] getDecodedId() { - try { - return extractDecodedId(encodedId); - } catch (IOException e) { - throw new AssertionError(e); - } - } - - @Override - public boolean equals(@Nullable Object obj) { - if (obj instanceof GroupId) { - return ((GroupId) obj).encodedId.equals(encodedId); - } - - return false; - } - - @Override - public int hashCode() { - return encodedId.hashCode(); - } - - @Override - public @NonNull String toString() { - return encodedId; - } - - @Override - public @NonNull String serialize() { - return encodedId; - } - - public abstract boolean isMms(); - - public abstract boolean isV1(); - - public abstract boolean isV2(); - - public abstract boolean isPush(); - - public GroupId.Mms requireMms() { - if (this instanceof GroupId.Mms) return (GroupId.Mms) this; - throw new AssertionError(); - } - - public GroupId.V1 requireV1() { - if (this instanceof GroupId.V1) return (GroupId.V1) this; - throw new AssertionError(); - } - - public GroupId.V2 requireV2() { - if (this instanceof GroupId.V2) return (GroupId.V2) this; - throw new AssertionError(); - } - - public GroupId.Push requirePush() { - if (this instanceof GroupId.Push) return (GroupId.Push) this; - throw new AssertionError(); - } - - public static final class Mms extends GroupId { - - private Mms(@NonNull byte[] bytes) { - super(ENCODED_MMS_GROUP_PREFIX, bytes); - } - - @Override - public boolean isMms() { - return true; - } - - @Override - public boolean isV1() { - return false; - } - - @Override - public boolean isV2() { - return false; - } - - @Override - public boolean isPush() { - return false; - } - } - - public static abstract class Push extends GroupId { - private Push(@NonNull String prefix, @NonNull byte[] bytes) { - super(prefix, bytes); - } - - @Override - public boolean isMms() { - return false; - } - - @Override - public boolean isPush() { - return true; - } - } - - public static final class V1 extends GroupId.Push { - - private V1(@NonNull byte[] bytes) { - super(ENCODED_SIGNAL_GROUP_V1_PREFIX, bytes); - } - - @Override - public boolean isV1() { - return true; - } - - @Override - public boolean isV2() { - return false; - } - - public GroupMasterKey deriveV2MigrationMasterKey() { - try { - return new GroupMasterKey(HKDF.deriveSecrets(getDecodedId(), "GV2 Migration".getBytes(), GroupMasterKey.SIZE)); - } catch (InvalidInputException e) { - throw new AssertionError(e); - } - } - - public GroupId.V2 deriveV2MigrationGroupId() { - return v2(deriveV2MigrationMasterKey()); - } - } - - public static final class V2 extends GroupId.Push { - - private V2(@NonNull byte[] bytes) { - super(ENCODED_SIGNAL_GROUP_V2_PREFIX, bytes); - } - - @Override - public boolean isV1() { - return false; - } - - @Override - public boolean isV2() { - return true; - } - } -} diff --git a/app/src/main/java/org/thoughtcrime/securesms/groups/GroupId.kt b/app/src/main/java/org/thoughtcrime/securesms/groups/GroupId.kt new file mode 100644 index 0000000000..ac20876d3d --- /dev/null +++ b/app/src/main/java/org/thoughtcrime/securesms/groups/GroupId.kt @@ -0,0 +1,330 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.thoughtcrime.securesms.groups + +import android.os.Parcelable +import kotlinx.parcelize.IgnoredOnParcel +import kotlinx.parcelize.Parcelize +import kotlinx.serialization.Serializable +import kotlinx.serialization.Transient +import okio.ByteString +import org.signal.core.util.DatabaseId +import org.signal.core.util.Hex +import org.signal.libsignal.protocol.kdf.HKDF +import org.signal.libsignal.zkgroup.InvalidInputException +import org.signal.libsignal.zkgroup.groups.GroupIdentifier +import org.signal.libsignal.zkgroup.groups.GroupMasterKey +import org.signal.libsignal.zkgroup.groups.GroupSecretParams +import org.thoughtcrime.securesms.util.LRUCache +import org.thoughtcrime.securesms.util.Util +import java.io.IOException +import java.security.SecureRandom + +/** + * GroupId represents the identifier for a given group. + * + * We suppress CanBe Parameter because Parcelize wants the parameters to exist for reconstruction. + */ +@Suppress("CanBeParameter") +@Parcelize +@Serializable +sealed class GroupId(private val encodedId: String) : DatabaseId, Parcelable { + + companion object { + private const val ENCODED_SIGNAL_GROUP_V1_PREFIX = "__textsecure_group__!" + private const val ENCODED_SIGNAL_GROUP_V2_PREFIX = "__signal_group__v2__!" + private const val ENCODED_MMS_GROUP_PREFIX = "__signal_mms_group__!" + private const val MMS_BYTE_LENGTH = 16 + private const val V1_MMS_BYTE_LENGTH = 16 + private const val V1_BYTE_LENGTH = 16 + private const val V2_BYTE_LENGTH = GroupIdentifier.SIZE + + private val groupIdentifierCache: LRUCache = LRUCache(1000) + + @JvmStatic + fun mms(mmsGroupIdBytes: ByteArray): Mms = Mms(mmsGroupIdBytes) + + @JvmStatic + fun v1orThrow(gv1GroupIdBytes: ByteArray): V1 { + try { + return v1(gv1GroupIdBytes) + } catch (e: BadGroupIdException) { + throw AssertionError(e) + } + } + + @JvmStatic + @Throws(BadGroupIdException::class) + fun v1(gv1GroupIdBytes: ByteArray): V1 { + if (gv1GroupIdBytes.size != V1_BYTE_LENGTH) { + throw BadGroupIdException() + } + + return V1(gv1GroupIdBytes) + } + + @JvmStatic + fun createV1(secureRandom: SecureRandom): V1 = v1orThrow(Util.getSecretBytes(secureRandom, V1_MMS_BYTE_LENGTH)) + + @JvmStatic + fun createMms(secureRandom: SecureRandom): Mms = mms(Util.getSecretBytes(secureRandom, MMS_BYTE_LENGTH)) + + @Throws(BadGroupIdException::class) + private fun v2(bytes: ByteArray): V2 { + if (bytes.size != V2_BYTE_LENGTH) { + throw BadGroupIdException() + } + + return V2(bytes) + } + + @JvmStatic + fun v2(groupIdentifier: GroupIdentifier): V2 { + try { + return v2(groupIdentifier.serialize()) + } catch (e: BadGroupIdException) { + throw AssertionError(e) + } + } + + @JvmStatic + fun v2(masterKey: GroupMasterKey): V2 = v2(getIdentifierForMasterKey(masterKey)) + + @JvmStatic + fun getIdentifierForMasterKey(masterKey: GroupMasterKey): GroupIdentifier { + var cachedIdentifier: GroupIdentifier? = null + synchronized(groupIdentifierCache) { + cachedIdentifier = groupIdentifierCache.get(masterKey) + } + + if (cachedIdentifier == null) { + cachedIdentifier = GroupSecretParams.deriveFromMasterKey(masterKey) + .publicParams + .groupIdentifier + + synchronized(groupIdentifierCache) { + groupIdentifierCache.put(masterKey, cachedIdentifier) + } + } + + return cachedIdentifier + } + + @JvmStatic + @Throws(BadGroupIdException::class) + fun push(bytes: ByteString): Push { + return push(bytes.toByteArray()) + } + + @JvmStatic + @Throws(BadGroupIdException::class) + fun push(bytes: ByteArray): Push { + return if (bytes.size == V2_BYTE_LENGTH) v2(bytes) else v1(bytes) + } + + @JvmStatic + fun pushOrThrow(bytes: ByteArray): Push { + try { + return push(bytes) + } catch (e: BadGroupIdException) { + throw AssertionError(e) + } + } + + @JvmStatic + fun pushOrNull(bytes: ByteArray): Push? { + return try { + push(bytes) + } catch (e: BadGroupIdException) { + null + } + } + + @JvmStatic + fun parseOrThrow(encodedGroupId: String): GroupId { + try { + return parse(encodedGroupId) + } catch (e: BadGroupIdException) { + throw AssertionError(e) + } + } + + @JvmStatic + @Throws(BadGroupIdException::class) + fun parse(encodedGroupId: String): GroupId { + try { + if (!isEncodedGroup(encodedGroupId)) { + throw BadGroupIdException("Invalid encoding") + } + + val bytes = extractDecodedId(encodedGroupId) + + when { + encodedGroupId.startsWith(ENCODED_SIGNAL_GROUP_V2_PREFIX) -> return v2(bytes) + encodedGroupId.startsWith(ENCODED_SIGNAL_GROUP_V1_PREFIX) -> return v1(bytes) + encodedGroupId.startsWith(ENCODED_MMS_GROUP_PREFIX) -> return mms(bytes) + } + + throw BadGroupIdException() + } catch (e: IOException) { + throw BadGroupIdException(e) + } + } + + @JvmStatic + @Throws(BadGroupIdException::class) + fun parseNullable(encodedGroupId: String?): GroupId? { + if (encodedGroupId == null) { + return null + } + + return parse(encodedGroupId) + } + + @JvmStatic + fun parseNullableOrThrow(encodedGroupId: String?): GroupId? { + if (encodedGroupId == null) { + return null + } + + return parseOrThrow(encodedGroupId) + } + + @JvmStatic + fun isEncodedGroup(groupId: String): Boolean { + return groupId.startsWith(ENCODED_SIGNAL_GROUP_V2_PREFIX) || + groupId.startsWith(ENCODED_SIGNAL_GROUP_V1_PREFIX) || + groupId.startsWith(ENCODED_MMS_GROUP_PREFIX) + } + + @Throws(IOException::class) + private fun extractDecodedId(encodedGroupId: String): ByteArray { + return Hex.fromStringCondensed(encodedGroupId.split("!".toPattern(), 2)[1]) + } + } + + constructor(prefix: String, bytes: ByteArray) : this(prefix + Hex.toStringCondensed(bytes)) + + val decodedId: ByteArray get() { + try { + return extractDecodedId(encodedId) + } catch (e: IOException) { + throw AssertionError(e) + } + } + + override fun toString(): String { + return encodedId + } + + override fun serialize(): String { + return encodedId + } + + abstract val isMms: Boolean + abstract val isV1: Boolean + abstract val isV2: Boolean + abstract val isPush: Boolean + + fun requireMms(): Mms { + assert(this is Mms) + return this as Mms + } + + fun requireV1(): V1 { + assert(this is V1) + return this as V1 + } + + fun requireV2(): V2 { + assert(this is V2) + return this as V2 + } + + fun requirePush(): Push { + assert(this is Push) + return this as Push + } + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (javaClass != other?.javaClass) return false + + other as GroupId + + return encodedId == other.encodedId + } + + override fun hashCode(): Int { + return encodedId.hashCode() + } + + @Serializable + class Mms(private val mmsBytes: ByteArray) : GroupId(ENCODED_MMS_GROUP_PREFIX, mmsBytes) { + @Transient + @IgnoredOnParcel + override val isMms: Boolean = true + + @Transient + @IgnoredOnParcel + override val isV1: Boolean = false + + @Transient + @IgnoredOnParcel + override val isV2: Boolean = false + + @Transient + @IgnoredOnParcel + override val isPush: Boolean = false + } + + @Serializable + sealed class Push(private val prefix: String, open val pushBytes: ByteArray) : GroupId(prefix, pushBytes) { + @Transient + @IgnoredOnParcel + override val isMms: Boolean = false + + @Transient + @IgnoredOnParcel + override val isPush: Boolean = true + + @Transient + @IgnoredOnParcel + override val isV1: Boolean = false + + @Transient + @IgnoredOnParcel + override val isV2: Boolean = false + } + + @Parcelize + @Serializable + class V1(private val v1Bytes: ByteArray) : Push(ENCODED_SIGNAL_GROUP_V1_PREFIX, v1Bytes) { + @Transient + @IgnoredOnParcel + override val isV1: Boolean = true + + fun deriveV2MigrationMasterKey(): GroupMasterKey { + try { + return GroupMasterKey(HKDF.deriveSecrets(decodedId, "GV2 Migration".toByteArray(), GroupMasterKey.SIZE)) + } catch (e: InvalidInputException) { + throw AssertionError(e) + } + } + + fun deriveV2MigrationGroupId(): V2 { + return v2(deriveV2MigrationMasterKey()) + } + } + + @Parcelize + @Serializable + class V2(private val v2Bytes: ByteArray) : Push(ENCODED_SIGNAL_GROUP_V2_PREFIX, v2Bytes) { + @Transient + @IgnoredOnParcel + override val isV2: Boolean = true + } +} diff --git a/app/src/main/java/org/thoughtcrime/securesms/recipients/Recipient.kt b/app/src/main/java/org/thoughtcrime/securesms/recipients/Recipient.kt index 0628a41546..c260b70c2c 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/recipients/Recipient.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/recipients/Recipient.kt @@ -197,28 +197,28 @@ class Recipient( val isMmsGroup: Boolean get() { val groupId = resolved.groupIdValue - return groupId != null && groupId.isMms() + return groupId != null && groupId.isMms } /** Whether the recipient represents a Signal group. */ val isPushGroup: Boolean get() { val groupId = resolved.groupIdValue - return groupId != null && groupId.isPush() + return groupId != null && groupId.isPush } /** Whether the recipient represents a V1 Signal group. These types of groups were deprecated in 2020. */ val isPushV1Group: Boolean get() { val groupId = resolved.groupIdValue - return groupId != null && groupId.isV1() + return groupId != null && groupId.isV1 } /** Whether the recipient represents a V2 Signal group. */ val isPushV2Group: Boolean get() { val groupId = resolved.groupIdValue - return groupId != null && groupId.isV2() + return groupId != null && groupId.isV2 } /** Whether the recipient represents a distribution list (a specific list of people to send a story to). */ diff --git a/app/src/test/java/org/thoughtcrime/securesms/conversation/colors/AvatarColorHashTest.kt b/app/src/test/java/org/thoughtcrime/securesms/conversation/colors/AvatarColorHashTest.kt index 9ad29fa179..3484213e04 100644 --- a/app/src/test/java/org/thoughtcrime/securesms/conversation/colors/AvatarColorHashTest.kt +++ b/app/src/test/java/org/thoughtcrime/securesms/conversation/colors/AvatarColorHashTest.kt @@ -31,6 +31,6 @@ class AvatarColorHashTest { @Test fun `hash test vector - GroupId`() { - assertEquals(AvatarColor.A130, AvatarColorHash.forGroupId(GroupId.V2.push(Base64.decode("BwJRIdomqOSOckHjnJsknNCibCZKJFt+RxLIpa9CWJ4=")))) + assertEquals(AvatarColor.A130, AvatarColorHash.forGroupId(GroupId.push(Base64.decode("BwJRIdomqOSOckHjnJsknNCibCZKJFt+RxLIpa9CWJ4=")))) } }