The rest of the storage service unwrapping.

This commit is contained in:
Greyson Parrelli
2024-11-13 10:35:02 -05:00
parent 8746f483c0
commit 7dd1fc09c0
33 changed files with 754 additions and 1692 deletions

View File

@@ -27,6 +27,7 @@ import org.thoughtcrime.securesms.storage.StorageRecordUpdate
import org.thoughtcrime.securesms.storage.StorageSyncHelper import org.thoughtcrime.securesms.storage.StorageSyncHelper
import org.whispersystems.signalservice.api.push.DistributionId import org.whispersystems.signalservice.api.push.DistributionId
import org.whispersystems.signalservice.api.storage.SignalStoryDistributionListRecord import org.whispersystems.signalservice.api.storage.SignalStoryDistributionListRecord
import org.whispersystems.signalservice.api.storage.recipientServiceAddresses
import org.whispersystems.signalservice.api.util.UuidUtil import org.whispersystems.signalservice.api.util.UuidUtil
import java.util.UUID import java.util.UUID
@@ -552,7 +553,7 @@ class DistributionListTables constructor(context: Context?, databaseHelper: Sign
} }
fun getRecipientIdForSyncRecord(record: SignalStoryDistributionListRecord): RecipientId? { fun getRecipientIdForSyncRecord(record: SignalStoryDistributionListRecord): RecipientId? {
val uuid: UUID = requireNotNull(UuidUtil.parseOrNull(record.identifier)) { "Incoming record did not have a valid identifier." } val uuid: UUID = requireNotNull(UuidUtil.parseOrNull(record.proto.identifier)) { "Incoming record did not have a valid identifier." }
val distributionId = DistributionId.from(uuid) val distributionId = DistributionId.from(uuid)
return readableDatabase.query( return readableDatabase.query(
@@ -591,30 +592,30 @@ class DistributionListTables constructor(context: Context?, databaseHelper: Sign
} }
fun applyStorageSyncStoryDistributionListInsert(insert: SignalStoryDistributionListRecord) { fun applyStorageSyncStoryDistributionListInsert(insert: SignalStoryDistributionListRecord) {
val distributionId = DistributionId.from(UuidUtil.parseOrThrow(insert.identifier)) val distributionId = DistributionId.from(UuidUtil.parseOrThrow(insert.proto.identifier))
if (distributionId == DistributionId.MY_STORY) { if (distributionId == DistributionId.MY_STORY) {
throw AssertionError("Should never try to insert My Story") throw AssertionError("Should never try to insert My Story")
} }
val privacyMode: DistributionListPrivacyMode = when { val privacyMode: DistributionListPrivacyMode = when {
insert.isBlockList && insert.recipients.isEmpty() -> DistributionListPrivacyMode.ALL insert.proto.isBlockList && insert.proto.recipientServiceIds.isEmpty() -> DistributionListPrivacyMode.ALL
insert.isBlockList -> DistributionListPrivacyMode.ALL_EXCEPT insert.proto.isBlockList -> DistributionListPrivacyMode.ALL_EXCEPT
else -> DistributionListPrivacyMode.ONLY_WITH else -> DistributionListPrivacyMode.ONLY_WITH
} }
createList( createList(
name = insert.name, name = insert.proto.name,
members = insert.recipients.map(RecipientId::from), members = insert.proto.recipientServiceAddresses.map(RecipientId::from),
distributionId = distributionId, distributionId = distributionId,
allowsReplies = insert.allowsReplies(), allowsReplies = insert.proto.allowsReplies,
deletionTimestamp = insert.deletedAtTimestamp, deletionTimestamp = insert.proto.deletedAtTimestamp,
privacyMode = privacyMode, privacyMode = privacyMode,
storageId = insert.id.raw storageId = insert.id.raw
) )
} }
fun applyStorageSyncStoryDistributionListUpdate(update: StorageRecordUpdate<SignalStoryDistributionListRecord>) { fun applyStorageSyncStoryDistributionListUpdate(update: StorageRecordUpdate<SignalStoryDistributionListRecord>) {
val distributionId = DistributionId.from(UuidUtil.parseOrThrow(update.new.identifier)) val distributionId = DistributionId.from(UuidUtil.parseOrThrow(update.new.proto.identifier))
val distributionListId: DistributionListId? = readableDatabase.query(ListTable.TABLE_NAME, arrayOf(ListTable.ID), "${ListTable.DISTRIBUTION_ID} = ?", SqlUtil.buildArgs(distributionId.toString()), null, null, null).use { cursor -> val distributionListId: DistributionListId? = readableDatabase.query(ListTable.TABLE_NAME, arrayOf(ListTable.ID), "${ListTable.DISTRIBUTION_ID} = ?", SqlUtil.buildArgs(distributionId.toString()), null, null, null).use { cursor ->
if (cursor == null || !cursor.moveToFirst()) { if (cursor == null || !cursor.moveToFirst()) {
@@ -632,26 +633,26 @@ class DistributionListTables constructor(context: Context?, databaseHelper: Sign
val recipientId = getRecipientId(distributionListId)!! val recipientId = getRecipientId(distributionListId)!!
SignalDatabase.recipients.updateStorageId(recipientId, update.new.id.raw) SignalDatabase.recipients.updateStorageId(recipientId, update.new.id.raw)
if (update.new.deletedAtTimestamp > 0L) { if (update.new.proto.deletedAtTimestamp > 0L) {
if (distributionId == DistributionId.MY_STORY) { if (distributionId == DistributionId.MY_STORY) {
Log.w(TAG, "Refusing to delete My Story.") Log.w(TAG, "Refusing to delete My Story.")
return return
} }
deleteList(distributionListId, update.new.deletedAtTimestamp) deleteList(distributionListId, update.new.proto.deletedAtTimestamp)
return return
} }
val privacyMode: DistributionListPrivacyMode = when { val privacyMode: DistributionListPrivacyMode = when {
update.new.isBlockList && update.new.recipients.isEmpty() -> DistributionListPrivacyMode.ALL update.new.proto.isBlockList && update.new.proto.recipientServiceIds.isEmpty() -> DistributionListPrivacyMode.ALL
update.new.isBlockList -> DistributionListPrivacyMode.ALL_EXCEPT update.new.proto.isBlockList -> DistributionListPrivacyMode.ALL_EXCEPT
else -> DistributionListPrivacyMode.ONLY_WITH else -> DistributionListPrivacyMode.ONLY_WITH
} }
writableDatabase.withinTransaction { writableDatabase.withinTransaction {
val listTableValues = contentValuesOf( val listTableValues = contentValuesOf(
ListTable.ALLOWS_REPLIES to update.new.allowsReplies(), ListTable.ALLOWS_REPLIES to update.new.proto.allowsReplies,
ListTable.NAME to update.new.name, ListTable.NAME to update.new.proto.name,
ListTable.IS_UNKNOWN to false, ListTable.IS_UNKNOWN to false,
ListTable.PRIVACY_MODE to privacyMode.serialize() ListTable.PRIVACY_MODE to privacyMode.serialize()
) )
@@ -664,7 +665,7 @@ class DistributionListTables constructor(context: Context?, databaseHelper: Sign
) )
val currentlyInDistributionList = getRawMembers(distributionListId, privacyMode).toSet() val currentlyInDistributionList = getRawMembers(distributionListId, privacyMode).toSet()
val shouldBeInDistributionList = update.new.recipients.map(RecipientId::from).toSet() val shouldBeInDistributionList = update.new.proto.recipientServiceAddresses.map(RecipientId::from).toSet()
val toRemove = currentlyInDistributionList - shouldBeInDistributionList val toRemove = currentlyInDistributionList - shouldBeInDistributionList
val toAdd = shouldBeInDistributionList - currentlyInDistributionList val toAdd = shouldBeInDistributionList - currentlyInDistributionList

View File

@@ -24,7 +24,6 @@ import org.signal.core.util.nullIfBlank
import org.signal.core.util.nullIfEmpty import org.signal.core.util.nullIfEmpty
import org.signal.core.util.optionalString import org.signal.core.util.optionalString
import org.signal.core.util.or import org.signal.core.util.or
import org.signal.core.util.orNull
import org.signal.core.util.readToList import org.signal.core.util.readToList
import org.signal.core.util.readToSet import org.signal.core.util.readToSet
import org.signal.core.util.readToSingleBoolean import org.signal.core.util.readToSingleBoolean
@@ -43,6 +42,7 @@ import org.signal.core.util.updateAll
import org.signal.core.util.withinTransaction import org.signal.core.util.withinTransaction
import org.signal.libsignal.protocol.IdentityKey import org.signal.libsignal.protocol.IdentityKey
import org.signal.libsignal.protocol.InvalidKeyException import org.signal.libsignal.protocol.InvalidKeyException
import org.signal.libsignal.zkgroup.groups.GroupMasterKey
import org.signal.libsignal.zkgroup.profiles.ExpiringProfileKeyCredential import org.signal.libsignal.zkgroup.profiles.ExpiringProfileKeyCredential
import org.signal.libsignal.zkgroup.profiles.ProfileKey import org.signal.libsignal.zkgroup.profiles.ProfileKey
import org.signal.storageservice.protos.groups.local.DecryptedGroup import org.signal.storageservice.protos.groups.local.DecryptedGroup
@@ -113,12 +113,13 @@ import org.whispersystems.signalservice.api.storage.SignalContactRecord
import org.whispersystems.signalservice.api.storage.SignalGroupV1Record import org.whispersystems.signalservice.api.storage.SignalGroupV1Record
import org.whispersystems.signalservice.api.storage.SignalGroupV2Record import org.whispersystems.signalservice.api.storage.SignalGroupV2Record
import org.whispersystems.signalservice.api.storage.StorageId import org.whispersystems.signalservice.api.storage.StorageId
import org.whispersystems.signalservice.api.storage.signalAci
import org.whispersystems.signalservice.api.storage.signalPni
import org.whispersystems.signalservice.internal.storage.protos.GroupV2Record import org.whispersystems.signalservice.internal.storage.protos.GroupV2Record
import java.io.Closeable import java.io.Closeable
import java.io.IOException import java.io.IOException
import java.util.Collections import java.util.Collections
import java.util.LinkedList import java.util.LinkedList
import java.util.Objects
import java.util.Optional import java.util.Optional
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import kotlin.jvm.optionals.getOrNull import kotlin.jvm.optionals.getOrNull
@@ -861,7 +862,7 @@ open class RecipientTable(context: Context, databaseHelper: SignalDatabase) : Da
val recipientId: RecipientId val recipientId: RecipientId
if (id < 0) { if (id < 0) {
Log.w(TAG, "[applyStorageSyncContactInsert] Failed to insert. Possibly merging.") Log.w(TAG, "[applyStorageSyncContactInsert] Failed to insert. Possibly merging.")
recipientId = getAndPossiblyMerge(aci = insert.aci.orNull(), pni = insert.pni.orNull(), e164 = insert.number.orNull(), pniVerified = insert.isPniSignatureVerified) recipientId = getAndPossiblyMerge(aci = ACI.parseOrNull(insert.proto.aci), pni = PNI.parseOrNull(insert.proto.pni), e164 = insert.proto.e164.nullIfBlank(), pniVerified = insert.proto.pniSignatureVerified)
resolvePotentialUsernameConflicts(values.getAsString(USERNAME), recipientId) resolvePotentialUsernameConflicts(values.getAsString(USERNAME), recipientId)
db.update(TABLE_NAME, values, ID_WHERE, SqlUtil.buildArgs(recipientId)) db.update(TABLE_NAME, values, ID_WHERE, SqlUtil.buildArgs(recipientId))
@@ -869,18 +870,18 @@ open class RecipientTable(context: Context, databaseHelper: SignalDatabase) : Da
recipientId = RecipientId.from(id) recipientId = RecipientId.from(id)
} }
if (insert.identityKey.isPresent && (insert.aci.isPresent || insert.pni.isPresent)) { if (insert.proto.identityKey.isNotEmpty() && (insert.proto.signalAci != null || insert.proto.signalPni != null)) {
try { try {
val serviceId: ServiceId = insert.aci.orNull() ?: insert.pni.get() val serviceId: ServiceId = insert.proto.signalAci ?: insert.proto.signalPni!!
val identityKey = IdentityKey(insert.identityKey.get(), 0) val identityKey = IdentityKey(insert.proto.identityKey.toByteArray(), 0)
identities.updateIdentityAfterSync(serviceId.toString(), recipientId, identityKey, StorageSyncModels.remoteToLocalIdentityStatus(insert.identityState)) identities.updateIdentityAfterSync(serviceId.toString(), recipientId, identityKey, StorageSyncModels.remoteToLocalIdentityStatus(insert.proto.identityState))
} catch (e: InvalidKeyException) { } catch (e: InvalidKeyException) {
Log.w(TAG, "Failed to process identity key during insert! Skipping.", e) Log.w(TAG, "Failed to process identity key during insert! Skipping.", e)
} }
} }
updateExtras(recipientId) { updateExtras(recipientId) {
it.hideStory(insert.shouldHideStory()) it.hideStory(insert.proto.hideStory)
} }
threadDatabase.applyStorageSyncUpdate(recipientId, insert) threadDatabase.applyStorageSyncUpdate(recipientId, insert)
@@ -901,7 +902,7 @@ open class RecipientTable(context: Context, databaseHelper: SignalDatabase) : Da
var recipientId = getByColumn(STORAGE_SERVICE_ID, Base64.encodeWithPadding(update.old.id.raw)).get() var recipientId = getByColumn(STORAGE_SERVICE_ID, Base64.encodeWithPadding(update.old.id.raw)).get()
Log.w(TAG, "[applyStorageSyncContactUpdate] Found user $recipientId. Possibly merging.") Log.w(TAG, "[applyStorageSyncContactUpdate] Found user $recipientId. Possibly merging.")
recipientId = getAndPossiblyMerge(aci = update.new.aci.orElse(null), pni = update.new.pni.orElse(null), e164 = update.new.number.orElse(null), pniVerified = update.new.isPniSignatureVerified) recipientId = getAndPossiblyMerge(aci = ACI.parseOrNull(update.new.proto.aci), pni = PNI.parseOrNull(update.new.proto.pni), e164 = update.new.proto.e164.nullIfBlank(), pniVerified = update.new.proto.pniSignatureVerified)
Log.w(TAG, "[applyStorageSyncContactUpdate] Merged into $recipientId") Log.w(TAG, "[applyStorageSyncContactUpdate] Merged into $recipientId")
resolvePotentialUsernameConflicts(values.getAsString(USERNAME), recipientId) resolvePotentialUsernameConflicts(values.getAsString(USERNAME), recipientId)
@@ -919,9 +920,9 @@ open class RecipientTable(context: Context, databaseHelper: SignalDatabase) : Da
try { try {
val oldIdentityRecord = identityStore.getIdentityRecord(recipientId) val oldIdentityRecord = identityStore.getIdentityRecord(recipientId)
if (update.new.identityKey.isPresent && update.new.aci.isPresent) { if (update.new.proto.identityKey.isNotEmpty() && update.new.proto.signalAci != null) {
val identityKey = IdentityKey(update.new.identityKey.get(), 0) val identityKey = IdentityKey(update.new.proto.identityKey.toByteArray(), 0)
identities.updateIdentityAfterSync(update.new.aci.get().toString(), recipientId, identityKey, StorageSyncModels.remoteToLocalIdentityStatus(update.new.identityState)) identities.updateIdentityAfterSync(update.new.proto.aci, recipientId, identityKey, StorageSyncModels.remoteToLocalIdentityStatus(update.new.proto.identityState))
} }
val newIdentityRecord = identityStore.getIdentityRecord(recipientId) val newIdentityRecord = identityStore.getIdentityRecord(recipientId)
@@ -935,7 +936,7 @@ open class RecipientTable(context: Context, databaseHelper: SignalDatabase) : Da
} }
updateExtras(recipientId) { updateExtras(recipientId) {
it.hideStory(update.new.shouldHideStory()) it.hideStory(update.new.proto.hideStory)
} }
threads.applyStorageSyncUpdate(recipientId, update.new) threads.applyStorageSyncUpdate(recipientId, update.new)
@@ -968,13 +969,13 @@ open class RecipientTable(context: Context, databaseHelper: SignalDatabase) : Da
throw AssertionError("Had an update, but it didn't match any rows!") throw AssertionError("Had an update, but it didn't match any rows!")
} }
val recipient = Recipient.externalGroupExact(GroupId.v1orThrow(update.old.groupId)) val recipient = Recipient.externalGroupExact(GroupId.v1orThrow(update.old.proto.id.toByteArray()))
threads.applyStorageSyncUpdate(recipient.id, update.new) threads.applyStorageSyncUpdate(recipient.id, update.new)
recipient.live().refresh() recipient.live().refresh()
} }
fun applyStorageSyncGroupV2Insert(insert: SignalGroupV2Record) { fun applyStorageSyncGroupV2Insert(insert: SignalGroupV2Record) {
val masterKey = insert.masterKeyOrThrow val masterKey = GroupMasterKey(insert.proto.masterKey.toByteArray())
val groupId = GroupId.v2(masterKey) val groupId = GroupId.v2(masterKey)
val values = getValuesForStorageGroupV2(insert, true) val values = getValuesForStorageGroupV2(insert, true)
@@ -991,12 +992,12 @@ open class RecipientTable(context: Context, databaseHelper: SignalDatabase) : Da
Log.w(TAG, "Unable to create restore placeholder for $groupId, group already exists") Log.w(TAG, "Unable to create restore placeholder for $groupId, group already exists")
} }
groups.setShowAsStoryState(groupId, insert.storySendMode.toShowAsStoryState()) groups.setShowAsStoryState(groupId, insert.proto.storySendMode.toShowAsStoryState())
val recipient = Recipient.externalGroupExact(groupId) val recipient = Recipient.externalGroupExact(groupId)
updateExtras(recipient.id) { updateExtras(recipient.id) {
it.hideStory(insert.shouldHideStory()) it.hideStory(insert.proto.hideStory)
} }
Log.i(TAG, "Scheduling request for latest group info for $groupId") Log.i(TAG, "Scheduling request for latest group info for $groupId")
@@ -1013,15 +1014,15 @@ open class RecipientTable(context: Context, databaseHelper: SignalDatabase) : Da
throw AssertionError("Had an update, but it didn't match any rows!") throw AssertionError("Had an update, but it didn't match any rows!")
} }
val masterKey = update.old.masterKeyOrThrow val masterKey = GroupMasterKey(update.old.proto.masterKey.toByteArray())
val groupId = GroupId.v2(masterKey) val groupId = GroupId.v2(masterKey)
val recipient = Recipient.externalGroupExact(groupId) val recipient = Recipient.externalGroupExact(groupId)
updateExtras(recipient.id) { updateExtras(recipient.id) {
it.hideStory(update.new.shouldHideStory()) it.hideStory(update.new.proto.hideStory)
} }
groups.setShowAsStoryState(groupId, update.new.storySendMode.toShowAsStoryState()) groups.setShowAsStoryState(groupId, update.new.proto.storySendMode.toShowAsStoryState())
threads.applyStorageSyncUpdate(recipient.id, update.new) threads.applyStorageSyncUpdate(recipient.id, update.new)
recipient.live().refresh() recipient.live().refresh()
} }
@@ -1051,7 +1052,7 @@ open class RecipientTable(context: Context, databaseHelper: SignalDatabase) : Da
put(STORAGE_SERVICE_ID, Base64.encodeWithPadding(update.new.id.raw)) put(STORAGE_SERVICE_ID, Base64.encodeWithPadding(update.new.id.raw))
if (update.new.proto.hasUnknownFields()) { if (update.new.proto.hasUnknownFields()) {
put(STORAGE_SERVICE_PROTO, Base64.encodeWithPadding(update.new.serializeUnknownFields()!!)) put(STORAGE_SERVICE_PROTO, Base64.encodeWithPadding(update.new.serializedUnknowns!!))
} else { } else {
putNull(STORAGE_SERVICE_PROTO) putNull(STORAGE_SERVICE_PROTO)
} }
@@ -4160,68 +4161,68 @@ open class RecipientTable(context: Context, databaseHelper: SignalDatabase) : Da
private fun getValuesForStorageContact(contact: SignalContactRecord, isInsert: Boolean): ContentValues { private fun getValuesForStorageContact(contact: SignalContactRecord, isInsert: Boolean): ContentValues {
return ContentValues().apply { return ContentValues().apply {
val profileName = ProfileName.fromParts(contact.profileGivenName.orElse(null), contact.profileFamilyName.orElse(null)) val profileName = ProfileName.fromParts(contact.proto.givenName.nullIfBlank(), contact.proto.familyName.nullIfBlank())
val systemName = ProfileName.fromParts(contact.systemGivenName.orElse(null), contact.systemFamilyName.orElse(null)) val systemName = ProfileName.fromParts(contact.proto.systemGivenName.nullIfBlank(), contact.proto.systemFamilyName.nullIfBlank())
val username = contact.username.orElse(null) val username = contact.proto.username.nullIfBlank()
val nickname = ProfileName.fromParts(contact.nicknameGivenName.orNull(), contact.nicknameFamilyName.orNull()) val nickname = ProfileName.fromParts(contact.proto.nickname?.given, contact.proto.nickname?.family)
put(ACI_COLUMN, contact.aci.orElse(null)?.toString()) put(ACI_COLUMN, contact.proto.signalAci?.toString())
put(PNI_COLUMN, contact.pni.orElse(null)?.toString()) put(PNI_COLUMN, contact.proto.signalPni?.toString())
put(E164, contact.number.orElse(null)) put(E164, contact.proto.e164.nullIfBlank())
put(PROFILE_GIVEN_NAME, profileName.givenName) put(PROFILE_GIVEN_NAME, profileName.givenName)
put(PROFILE_FAMILY_NAME, profileName.familyName) put(PROFILE_FAMILY_NAME, profileName.familyName)
put(PROFILE_JOINED_NAME, profileName.toString()) put(PROFILE_JOINED_NAME, profileName.toString())
put(SYSTEM_GIVEN_NAME, systemName.givenName) put(SYSTEM_GIVEN_NAME, systemName.givenName)
put(SYSTEM_FAMILY_NAME, systemName.familyName) put(SYSTEM_FAMILY_NAME, systemName.familyName)
put(SYSTEM_JOINED_NAME, systemName.toString()) put(SYSTEM_JOINED_NAME, systemName.toString())
put(SYSTEM_NICKNAME, contact.systemNickname.orElse(null)) put(SYSTEM_NICKNAME, contact.proto.systemNickname.nullIfBlank())
put(PROFILE_KEY, contact.profileKey.map { source -> Base64.encodeWithPadding(source) }.orElse(null)) put(PROFILE_KEY, contact.proto.profileKey.takeIf { it.isNotEmpty() }?.let { source -> Base64.encodeWithPadding(source.toByteArray()) })
put(USERNAME, if (TextUtils.isEmpty(username)) null else username) put(USERNAME, if (TextUtils.isEmpty(username)) null else username)
put(PROFILE_SHARING, if (contact.isProfileSharingEnabled) "1" else "0") put(PROFILE_SHARING, contact.proto.whitelisted.toInt())
put(BLOCKED, if (contact.isBlocked) "1" else "0") put(BLOCKED, contact.proto.blocked.toInt())
put(MUTE_UNTIL, contact.muteUntil) put(MUTE_UNTIL, contact.proto.mutedUntilTimestamp)
put(STORAGE_SERVICE_ID, Base64.encodeWithPadding(contact.id.raw)) put(STORAGE_SERVICE_ID, Base64.encodeWithPadding(contact.id.raw))
put(HIDDEN, contact.isHidden) put(HIDDEN, contact.proto.hidden)
put(PNI_SIGNATURE_VERIFIED, contact.isPniSignatureVerified.toInt()) put(PNI_SIGNATURE_VERIFIED, contact.proto.pniSignatureVerified.toInt())
put(NICKNAME_GIVEN_NAME, nickname.givenName.nullIfBlank()) put(NICKNAME_GIVEN_NAME, nickname.givenName.nullIfBlank())
put(NICKNAME_FAMILY_NAME, nickname.familyName.nullIfBlank()) put(NICKNAME_FAMILY_NAME, nickname.familyName.nullIfBlank())
put(NICKNAME_JOINED_NAME, nickname.toString().nullIfBlank()) put(NICKNAME_JOINED_NAME, nickname.toString().nullIfBlank())
put(NOTE, contact.note.orNull().nullIfBlank()) put(NOTE, contact.proto.note.nullIfBlank())
if (contact.hasUnknownFields()) { if (contact.proto.hasUnknownFields()) {
put(STORAGE_SERVICE_PROTO, Base64.encodeWithPadding(Objects.requireNonNull(contact.serializeUnknownFields()))) put(STORAGE_SERVICE_PROTO, Base64.encodeWithPadding(contact.serializedUnknowns!!))
} else { } else {
putNull(STORAGE_SERVICE_PROTO) putNull(STORAGE_SERVICE_PROTO)
} }
put(UNREGISTERED_TIMESTAMP, contact.unregisteredTimestamp) put(UNREGISTERED_TIMESTAMP, contact.proto.unregisteredAtTimestamp)
if (contact.unregisteredTimestamp > 0L) { if (contact.proto.unregisteredAtTimestamp > 0L) {
put(REGISTERED, RegisteredState.NOT_REGISTERED.id) put(REGISTERED, RegisteredState.NOT_REGISTERED.id)
} else if (contact.aci.isPresent) { } else if (contact.proto.signalAci != null) {
put(REGISTERED, RegisteredState.REGISTERED.id) put(REGISTERED, RegisteredState.REGISTERED.id)
} else { } else {
Log.w(TAG, "Contact is marked as registered, but has no serviceId! Can't locally mark registered. (Phone: ${contact.number.orElse("null")}, Username: ${username?.isNotEmpty()})") Log.w(TAG, "Contact is marked as registered, but has no serviceId! Can't locally mark registered. (Phone: ${contact.proto.e164.nullIfBlank()}, Username: ${username?.isNotEmpty()})")
} }
if (isInsert) { if (isInsert) {
put(AVATAR_COLOR, AvatarColorHash.forAddress(contact.aci.map { it.toString() }.or(contact.pni.map { it.toString() }).orNull(), contact.number.orNull()).serialize()) put(AVATAR_COLOR, AvatarColorHash.forAddress(contact.proto.signalAci?.toString() ?: contact.proto.signalPni?.toString(), contact.proto.e164).serialize())
} }
} }
} }
private fun getValuesForStorageGroupV1(groupV1: SignalGroupV1Record, isInsert: Boolean): ContentValues { private fun getValuesForStorageGroupV1(groupV1: SignalGroupV1Record, isInsert: Boolean): ContentValues {
return ContentValues().apply { return ContentValues().apply {
val groupId = GroupId.v1orThrow(groupV1.groupId) val groupId = GroupId.v1orThrow(groupV1.proto.id.toByteArray())
put(GROUP_ID, groupId.toString()) put(GROUP_ID, groupId.toString())
put(TYPE, RecipientType.GV1.id) put(TYPE, RecipientType.GV1.id)
put(PROFILE_SHARING, if (groupV1.isProfileSharingEnabled) "1" else "0") put(PROFILE_SHARING, if (groupV1.proto.whitelisted) "1" else "0")
put(BLOCKED, if (groupV1.isBlocked) "1" else "0") put(BLOCKED, if (groupV1.proto.blocked) "1" else "0")
put(MUTE_UNTIL, groupV1.muteUntil) put(MUTE_UNTIL, groupV1.proto.mutedUntilTimestamp)
put(STORAGE_SERVICE_ID, Base64.encodeWithPadding(groupV1.id.raw)) put(STORAGE_SERVICE_ID, Base64.encodeWithPadding(groupV1.id.raw))
if (groupV1.hasUnknownFields()) { if (groupV1.proto.hasUnknownFields()) {
put(STORAGE_SERVICE_PROTO, Base64.encodeWithPadding(groupV1.serializeUnknownFields())) put(STORAGE_SERVICE_PROTO, Base64.encodeWithPadding(groupV1.serializedUnknowns!!))
} else { } else {
putNull(STORAGE_SERVICE_PROTO) putNull(STORAGE_SERVICE_PROTO)
} }
@@ -4234,18 +4235,18 @@ open class RecipientTable(context: Context, databaseHelper: SignalDatabase) : Da
private fun getValuesForStorageGroupV2(groupV2: SignalGroupV2Record, isInsert: Boolean): ContentValues { private fun getValuesForStorageGroupV2(groupV2: SignalGroupV2Record, isInsert: Boolean): ContentValues {
return ContentValues().apply { return ContentValues().apply {
val groupId = GroupId.v2(groupV2.masterKeyOrThrow) val groupId = GroupId.v2(GroupMasterKey(groupV2.proto.masterKey.toByteArray()))
put(GROUP_ID, groupId.toString()) put(GROUP_ID, groupId.toString())
put(TYPE, RecipientType.GV2.id) put(TYPE, RecipientType.GV2.id)
put(PROFILE_SHARING, if (groupV2.isProfileSharingEnabled) "1" else "0") put(PROFILE_SHARING, if (groupV2.proto.whitelisted) "1" else "0")
put(BLOCKED, if (groupV2.isBlocked) "1" else "0") put(BLOCKED, if (groupV2.proto.blocked) "1" else "0")
put(MUTE_UNTIL, groupV2.muteUntil) put(MUTE_UNTIL, groupV2.proto.mutedUntilTimestamp)
put(STORAGE_SERVICE_ID, Base64.encodeWithPadding(groupV2.id.raw)) put(STORAGE_SERVICE_ID, Base64.encodeWithPadding(groupV2.id.raw))
put(MENTION_SETTING, if (groupV2.notifyForMentionsWhenMuted()) MentionSetting.ALWAYS_NOTIFY.id else MentionSetting.DO_NOT_NOTIFY.id) put(MENTION_SETTING, if (groupV2.proto.dontNotifyForMentionsIfMuted) MentionSetting.DO_NOT_NOTIFY.id else MentionSetting.ALWAYS_NOTIFY.id)
if (groupV2.hasUnknownFields()) { if (groupV2.proto.hasUnknownFields()) {
put(STORAGE_SERVICE_PROTO, Base64.encodeWithPadding(groupV2.serializeUnknownFields())) put(STORAGE_SERVICE_PROTO, Base64.encodeWithPadding(groupV2.serializedUnknowns!!))
} else { } else {
putNull(STORAGE_SERVICE_PROTO) putNull(STORAGE_SERVICE_PROTO)
} }

View File

@@ -1510,15 +1510,15 @@ class ThreadTable(context: Context, databaseHelper: SignalDatabase) : DatabaseTa
} }
fun applyStorageSyncUpdate(recipientId: RecipientId, record: SignalContactRecord) { fun applyStorageSyncUpdate(recipientId: RecipientId, record: SignalContactRecord) {
applyStorageSyncUpdate(recipientId, record.isArchived, record.isForcedUnread) applyStorageSyncUpdate(recipientId, record.proto.archived, record.proto.markedUnread)
} }
fun applyStorageSyncUpdate(recipientId: RecipientId, record: SignalGroupV1Record) { fun applyStorageSyncUpdate(recipientId: RecipientId, record: SignalGroupV1Record) {
applyStorageSyncUpdate(recipientId, record.isArchived, record.isForcedUnread) applyStorageSyncUpdate(recipientId, record.proto.archived, record.proto.markedUnread)
} }
fun applyStorageSyncUpdate(recipientId: RecipientId, record: SignalGroupV2Record) { fun applyStorageSyncUpdate(recipientId: RecipientId, record: SignalGroupV2Record) {
applyStorageSyncUpdate(recipientId, record.isArchived, record.isForcedUnread) applyStorageSyncUpdate(recipientId, record.proto.archived, record.proto.markedUnread)
} }
fun applyStorageSyncUpdate(recipientId: RecipientId, record: SignalAccountRecord) { fun applyStorageSyncUpdate(recipientId: RecipientId, record: SignalAccountRecord) {

View File

@@ -104,7 +104,7 @@ class AccountRecordProcessor(
remote.proto.storyViewReceiptsEnabled remote.proto.storyViewReceiptsEnabled
} }
val unknownFields = remote.serializeUnknownFields() val unknownFields = remote.serializedUnknowns
val merged = SignalAccountRecord.newBuilder(unknownFields).apply { val merged = SignalAccountRecord.newBuilder(unknownFields).apply {
givenName = mergedGivenName givenName = mergedGivenName
@@ -162,8 +162,4 @@ class AccountRecordProcessor(
override fun compare(lhs: SignalAccountRecord, rhs: SignalAccountRecord): Int { override fun compare(lhs: SignalAccountRecord, rhs: SignalAccountRecord): Int {
return 0 return 0
} }
private fun doParamsMatch(base: SignalAccountRecord, test: SignalAccountRecord): Boolean {
return base.serializeUnknownFields().contentEquals(test.serializeUnknownFields()) && base.proto == test.proto
}
} }

View File

@@ -5,21 +5,30 @@
package org.thoughtcrime.securesms.storage package org.thoughtcrime.securesms.storage
import okio.ByteString.Companion.toByteString
import org.signal.core.util.isNotEmpty
import org.signal.core.util.logging.Log import org.signal.core.util.logging.Log
import org.signal.core.util.toOptional
import org.signal.ringrtc.CallLinkRootKey import org.signal.ringrtc.CallLinkRootKey
import org.thoughtcrime.securesms.database.SignalDatabase import org.thoughtcrime.securesms.database.SignalDatabase
import org.thoughtcrime.securesms.service.webrtc.links.CallLinkRoomId import org.thoughtcrime.securesms.service.webrtc.links.CallLinkRoomId
import org.whispersystems.signalservice.api.storage.SignalCallLinkRecord import org.whispersystems.signalservice.api.storage.SignalCallLinkRecord
import org.whispersystems.signalservice.api.storage.StorageId
import org.whispersystems.signalservice.api.storage.toSignalCallLinkRecord
import java.util.Optional import java.util.Optional
internal class CallLinkRecordProcessor : DefaultStorageRecordProcessor<SignalCallLinkRecord>() { /**
* Record processor for [SignalCallLinkRecord].
* Handles merging and updating our local store when processing remote call link storage records.
*/
class CallLinkRecordProcessor : DefaultStorageRecordProcessor<SignalCallLinkRecord>() {
companion object { companion object {
private val TAG = Log.tag(CallLinkRecordProcessor::class) private val TAG = Log.tag(CallLinkRecordProcessor::class)
} }
override fun compare(o1: SignalCallLinkRecord?, o2: SignalCallLinkRecord?): Int { override fun compare(o1: SignalCallLinkRecord?, o2: SignalCallLinkRecord?): Int {
return if (o1?.rootKey.contentEquals(o2?.rootKey)) { return if (o1?.proto?.rootKey == o2?.proto?.rootKey) {
0 0
} else { } else {
1 1
@@ -27,21 +36,21 @@ internal class CallLinkRecordProcessor : DefaultStorageRecordProcessor<SignalCal
} }
override fun isInvalid(remote: SignalCallLinkRecord): Boolean { override fun isInvalid(remote: SignalCallLinkRecord): Boolean {
return remote.adminPassKey.isNotEmpty() && remote.deletionTimestamp > 0L return remote.proto.adminPasskey.isNotEmpty() && remote.proto.deletedAtTimestampMs > 0L
} }
override fun getMatching(remote: SignalCallLinkRecord, keyGenerator: StorageKeyGenerator): Optional<SignalCallLinkRecord> { override fun getMatching(remote: SignalCallLinkRecord, keyGenerator: StorageKeyGenerator): Optional<SignalCallLinkRecord> {
Log.d(TAG, "Attempting to get matching record...") Log.d(TAG, "Attempting to get matching record...")
val rootKey = CallLinkRootKey(remote.rootKey) val callRootKey = CallLinkRootKey(remote.proto.rootKey.toByteArray())
val roomId = CallLinkRoomId.fromCallLinkRootKey(rootKey) val roomId = CallLinkRoomId.fromCallLinkRootKey(callRootKey)
val callLink = SignalDatabase.callLinks.getCallLinkByRoomId(roomId) val callLink = SignalDatabase.callLinks.getCallLinkByRoomId(roomId)
if (callLink != null && callLink.credentials?.adminPassBytes != null) { if (callLink != null && callLink.credentials?.adminPassBytes != null) {
val builder = SignalCallLinkRecord.Builder(keyGenerator.generate(), null).apply { return SignalCallLinkRecord.newBuilder(null).apply {
setRootKey(rootKey.keyBytes) rootKey = callRootKey.keyBytes.toByteString()
setAdminPassKey(callLink.credentials.adminPassBytes) adminPasskey = callLink.credentials.adminPassBytes.toByteString()
setDeletedTimestamp(callLink.deletionTimestamp) deletedAtTimestampMs = callLink.deletionTimestamp
} }.build().toSignalCallLinkRecord(StorageId.forCallLink(keyGenerator.generate())).toOptional()
return Optional.of(builder.build())
} else { } else {
return Optional.empty<SignalCallLinkRecord>() return Optional.empty<SignalCallLinkRecord>()
} }
@@ -53,15 +62,15 @@ internal class CallLinkRecordProcessor : DefaultStorageRecordProcessor<SignalCal
* Other fields should not change, except for the clearing of the admin passkey on deletion * Other fields should not change, except for the clearing of the admin passkey on deletion
*/ */
override fun merge(remote: SignalCallLinkRecord, local: SignalCallLinkRecord, keyGenerator: StorageKeyGenerator): SignalCallLinkRecord { override fun merge(remote: SignalCallLinkRecord, local: SignalCallLinkRecord, keyGenerator: StorageKeyGenerator): SignalCallLinkRecord {
return if (remote.isDeleted() && local.isDeleted()) { return if (remote.proto.deletedAtTimestampMs > 0 && local.proto.deletedAtTimestampMs > 0) {
if (remote.deletionTimestamp < local.deletionTimestamp) { if (remote.proto.deletedAtTimestampMs < local.proto.deletedAtTimestampMs) {
remote remote
} else { } else {
local local
} }
} else if (remote.isDeleted()) { } else if (remote.proto.deletedAtTimestampMs > 0) {
remote remote
} else if (local.isDeleted()) { } else if (local.proto.deletedAtTimestampMs > 0) {
local local
} else { } else {
remote remote
@@ -77,12 +86,12 @@ internal class CallLinkRecordProcessor : DefaultStorageRecordProcessor<SignalCal
} }
private fun insertOrUpdateRecord(record: SignalCallLinkRecord) { private fun insertOrUpdateRecord(record: SignalCallLinkRecord) {
val rootKey = CallLinkRootKey(record.rootKey) val rootKey = CallLinkRootKey(record.proto.rootKey.toByteArray())
SignalDatabase.callLinks.insertOrUpdateCallLinkByRootKey( SignalDatabase.callLinks.insertOrUpdateCallLinkByRootKey(
callLinkRootKey = rootKey, callLinkRootKey = rootKey,
adminPassKey = record.adminPassKey, adminPassKey = record.proto.adminPasskey.toByteArray(),
deletionTimestamp = record.deletionTimestamp, deletionTimestamp = record.proto.deletedAtTimestampMs,
storageId = record.id storageId = record.id
) )
} }

View File

@@ -1,340 +0,0 @@
package org.thoughtcrime.securesms.storage;
import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import org.signal.core.util.StringUtil;
import org.signal.core.util.logging.Log;
import org.thoughtcrime.securesms.database.RecipientTable;
import org.thoughtcrime.securesms.database.SignalDatabase;
import org.thoughtcrime.securesms.database.model.RecipientRecord;
import org.thoughtcrime.securesms.jobs.RetrieveProfileJob;
import org.thoughtcrime.securesms.keyvalue.SignalStore;
import org.thoughtcrime.securesms.recipients.Recipient;
import org.thoughtcrime.securesms.recipients.RecipientId;
import org.whispersystems.signalservice.api.push.ServiceId.ACI;
import org.whispersystems.signalservice.api.push.ServiceId.PNI;
import org.whispersystems.signalservice.api.storage.SignalContactRecord;
import org.whispersystems.signalservice.api.util.OptionalUtil;
import org.whispersystems.signalservice.internal.storage.protos.ContactRecord.IdentityState;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.TreeSet;
import java.util.regex.Pattern;
public class ContactRecordProcessor extends DefaultStorageRecordProcessor<SignalContactRecord> {
private static final String TAG = Log.tag(ContactRecordProcessor.class);
private static final Pattern E164_PATTERN = Pattern.compile("^\\+[1-9]\\d{0,18}$");
private final RecipientTable recipientTable;
private final ACI selfAci;
private final PNI selfPni;
private final String selfE164;
public ContactRecordProcessor() {
this(SignalStore.account().getAci(),
SignalStore.account().getPni(),
SignalStore.account().getE164(),
SignalDatabase.recipients());
}
ContactRecordProcessor(@Nullable ACI selfAci, @Nullable PNI selfPni, @Nullable String selfE164, @NonNull RecipientTable recipientTable) {
this.recipientTable = recipientTable;
this.selfAci = selfAci;
this.selfPni = selfPni;
this.selfE164 = selfE164;
}
/**
* For contact records specifically, we have some extra work that needs to be done before we process all of the records.
*
* We have to find all unregistered ACI-only records and split them into two separate contact rows locally, if necessary.
* The reasons are nuanced, but the TL;DR is that we want to split unregistered users into separate rows so that a user
* could re-register and get a different ACI.
*/
@Override
public void process(@NonNull Collection<? extends SignalContactRecord> remoteRecords, @NonNull StorageKeyGenerator keyGenerator) throws IOException {
List<SignalContactRecord> unregisteredAciOnly = new ArrayList<>();
for (SignalContactRecord remoteRecord : remoteRecords) {
if (isInvalid(remoteRecord)) {
continue;
}
if (remoteRecord.getUnregisteredTimestamp() > 0 && remoteRecord.getAci().isPresent() && remoteRecord.getPni().isEmpty() && remoteRecord.getNumber().isEmpty()) {
unregisteredAciOnly.add(remoteRecord);
}
}
if (unregisteredAciOnly.size() > 0) {
for (SignalContactRecord aciOnly : unregisteredAciOnly) {
SignalDatabase.recipients().splitForStorageSyncIfNecessary(aciOnly.getAci().get());
}
}
super.process(remoteRecords, keyGenerator);
}
/**
* Error cases:
* - You can't have a contact record without an ACI or PNI.
* - You can't have a contact record for yourself. That should be an account record.
*
* Note: This method could be written more succinctly, but the logs are useful :)
*/
@Override
public boolean isInvalid(@NonNull SignalContactRecord remote) {
boolean hasAci = remote.getAci().isPresent() && remote.getAci().get().isValid();
boolean hasPni = remote.getPni().isPresent() && remote.getPni().get().isValid();
if (!hasAci && !hasPni) {
Log.w(TAG, "Found a ContactRecord with neither an ACI nor a PNI -- marking as invalid.");
return true;
} else if (selfAci != null && selfAci.equals(remote.getAci().orElse(null)) ||
(selfPni != null && selfPni.equals(remote.getPni().orElse(null))) ||
(selfE164 != null && remote.getNumber().isPresent() && remote.getNumber().get().equals(selfE164)))
{
Log.w(TAG, "Found a ContactRecord for ourselves -- marking as invalid.");
return true;
} else if (remote.getNumber().isPresent() && !isValidE164(remote.getNumber().get())) {
Log.w(TAG, "Found a record with an invalid E164. Marking as invalid.");
return true;
} else {
return false;
}
}
@Override
public @NonNull Optional<SignalContactRecord> getMatching(@NonNull SignalContactRecord remote, @NonNull StorageKeyGenerator keyGenerator) {
Optional<RecipientId> found = remote.getAci().isPresent() ? recipientTable.getByAci(remote.getAci().get()) : Optional.empty();
if (found.isEmpty() && remote.getNumber().isPresent()) {
found = recipientTable.getByE164(remote.getNumber().get());
}
if (found.isEmpty() && remote.getPni().isPresent()) {
found = recipientTable.getByPni(remote.getPni().get());
}
return found.map(recipientTable::getRecordForSync)
.map(settings -> {
if (settings.getStorageId() != null) {
return StorageSyncModels.localToRemoteRecord(settings);
} else {
Log.w(TAG, "Newly discovering a registered user via storage service. Saving a storageId for them.");
recipientTable.updateStorageId(settings.getId(), keyGenerator.generate());
RecipientRecord updatedSettings = Objects.requireNonNull(recipientTable.getRecordForSync(settings.getId()));
return StorageSyncModels.localToRemoteRecord(updatedSettings);
}
})
.map(r -> new SignalContactRecord(r.getId(), r.getProto().contact));
}
@Override
public @NonNull SignalContactRecord merge(@NonNull SignalContactRecord remote, @NonNull SignalContactRecord local, @NonNull StorageKeyGenerator keyGenerator) {
String profileGivenName;
String profileFamilyName;
if (remote.getProfileGivenName().isPresent() || remote.getProfileFamilyName().isPresent()) {
profileGivenName = remote.getProfileGivenName().orElse("");
profileFamilyName = remote.getProfileFamilyName().orElse("");
} else {
profileGivenName = local.getProfileGivenName().orElse("");
profileFamilyName = local.getProfileFamilyName().orElse("");
}
IdentityState identityState;
byte[] identityKey;
if ((remote.getIdentityState() != local.getIdentityState() && remote.getIdentityKey().isPresent()) ||
(remote.getIdentityKey().isPresent() && local.getIdentityKey().isEmpty()) ||
(remote.getIdentityKey().isPresent() && local.getUnregisteredTimestamp() > 0))
{
identityState = remote.getIdentityState();
identityKey = remote.getIdentityKey().get();
} else {
identityState = local.getIdentityState();
identityKey = local.getIdentityKey().orElse(null);
}
if (local.getAci().isPresent() && identityKey != null && remote.getIdentityKey().isPresent() && !Arrays.equals(identityKey, remote.getIdentityKey().get())) {
Log.w(TAG, "The local and remote identity keys do not match for " + local.getAci().orElse(null) + ". Enqueueing a profile fetch.");
RetrieveProfileJob.enqueue(Recipient.trustedPush(local.getAci().get(), local.getPni().orElse(null), local.getNumber().orElse(null)).getId());
}
PNI pni;
String e164;
boolean e164sMatchButPnisDont = local.getNumber().isPresent() &&
local.getNumber().get().equals(remote.getNumber().orElse(null)) &&
local.getPni().isPresent() &&
remote.getPni().isPresent() &&
!local.getPni().get().equals(remote.getPni().get());
boolean pnisMatchButE164sDont = local.getPni().isPresent() &&
local.getPni().get().equals(remote.getPni().orElse(null)) &&
local.getNumber().isPresent() &&
remote.getNumber().isPresent() &&
!local.getNumber().get().equals(remote.getNumber().get());
if (e164sMatchButPnisDont) {
Log.w(TAG, "Matching E164s, but the PNIs differ! Trusting our local pair.");
// TODO [pnp] Schedule CDS fetch?
pni = local.getPni().get();
e164 = local.getNumber().get();
} else if (pnisMatchButE164sDont) {
Log.w(TAG, "Matching PNIs, but the E164s differ! Trusting our local pair.");
// TODO [pnp] Schedule CDS fetch?
pni = local.getPni().get();
e164 = local.getNumber().get();
} else {
pni = OptionalUtil.or(remote.getPni(), local.getPni()).orElse(null);
e164 = OptionalUtil.or(remote.getNumber(), local.getNumber()).orElse(null);
}
byte[] unknownFields = remote.serializeUnknownFields();
ACI aci = local.getAci().isEmpty() ? remote.getAci().orElse(null) : local.getAci().get();
byte[] profileKey = OptionalUtil.or(remote.getProfileKey(), local.getProfileKey()).orElse(null);
String username = OptionalUtil.or(remote.getUsername(), local.getUsername()).orElse("");
boolean blocked = remote.isBlocked();
boolean profileSharing = remote.isProfileSharingEnabled();
boolean archived = remote.isArchived();
boolean forcedUnread = remote.isForcedUnread();
long muteUntil = remote.getMuteUntil();
boolean hideStory = remote.shouldHideStory();
long unregisteredTimestamp = remote.getUnregisteredTimestamp();
boolean hidden = remote.isHidden();
String systemGivenName = SignalStore.account().isPrimaryDevice() ? local.getSystemGivenName().orElse("") : remote.getSystemGivenName().orElse("");
String systemFamilyName = SignalStore.account().isPrimaryDevice() ? local.getSystemFamilyName().orElse("") : remote.getSystemFamilyName().orElse("");
String systemNickname = remote.getSystemNickname().orElse("");
String nicknameGivenName = remote.getNicknameGivenName().orElse("");
String nicknameFamilyName = remote.getNicknameFamilyName().orElse("");
boolean pniSignatureVerified = remote.isPniSignatureVerified() || local.isPniSignatureVerified();
String note = remote.getNote().or(local::getNote).orElse("");
boolean matchesRemote = doParamsMatch(remote, unknownFields, aci, pni, e164, profileGivenName, profileFamilyName, systemGivenName, systemFamilyName, systemNickname, profileKey, username, identityState, identityKey, blocked, profileSharing, archived, forcedUnread, muteUntil, hideStory, unregisteredTimestamp, hidden, pniSignatureVerified, nicknameGivenName, nicknameFamilyName, note);
boolean matchesLocal = doParamsMatch(local, unknownFields, aci, pni, e164, profileGivenName, profileFamilyName, systemGivenName, systemFamilyName, systemNickname, profileKey, username, identityState, identityKey, blocked, profileSharing, archived, forcedUnread, muteUntil, hideStory, unregisteredTimestamp, hidden, pniSignatureVerified, nicknameGivenName, nicknameFamilyName, note);
if (matchesRemote) {
return remote;
} else if (matchesLocal) {
return local;
} else {
return new SignalContactRecord.Builder(keyGenerator.generate(), aci, unknownFields)
.setE164(e164)
.setPni(pni)
.setProfileGivenName(profileGivenName)
.setProfileFamilyName(profileFamilyName)
.setSystemGivenName(systemGivenName)
.setSystemFamilyName(systemFamilyName)
.setSystemNickname(systemNickname)
.setProfileKey(profileKey)
.setUsername(username)
.setIdentityState(identityState)
.setIdentityKey(identityKey)
.setBlocked(blocked)
.setProfileSharingEnabled(profileSharing)
.setArchived(archived)
.setForcedUnread(forcedUnread)
.setMuteUntil(muteUntil)
.setHideStory(hideStory)
.setUnregisteredTimestamp(unregisteredTimestamp)
.setHidden(hidden)
.setPniSignatureVerified(pniSignatureVerified)
.setNicknameGivenName(nicknameGivenName)
.setNicknameFamilyName(nicknameFamilyName)
.setNote(note)
.build();
}
}
@Override
public void insertLocal(@NonNull SignalContactRecord record) {
recipientTable.applyStorageSyncContactInsert(record);
}
@Override
public void updateLocal(@NonNull StorageRecordUpdate<SignalContactRecord> update) {
recipientTable.applyStorageSyncContactUpdate(update);
}
@Override
public int compare(@NonNull SignalContactRecord lhs, @NonNull SignalContactRecord rhs) {
if ((lhs.getAci().isPresent() && Objects.equals(lhs.getAci(), rhs.getAci())) ||
(lhs.getNumber().isPresent() && Objects.equals(lhs.getNumber(), rhs.getNumber())) ||
(lhs.getPni().isPresent() && Objects.equals(lhs.getPni(), rhs.getPni())))
{
return 0;
} else {
return 1;
}
}
private static boolean isValidE164(String value) {
return E164_PATTERN.matcher(value).matches();
}
private static boolean doParamsMatch(@NonNull SignalContactRecord contact,
@Nullable byte[] unknownFields,
@Nullable ACI aci,
@Nullable PNI pni,
@Nullable String e164,
@NonNull String profileGivenName,
@NonNull String profileFamilyName,
@NonNull String systemGivenName,
@NonNull String systemFamilyName,
@NonNull String systemNickname,
@Nullable byte[] profileKey,
@NonNull String username,
@Nullable IdentityState identityState,
@Nullable byte[] identityKey,
boolean blocked,
boolean profileSharing,
boolean archived,
boolean forcedUnread,
long muteUntil,
boolean hideStory,
long unregisteredTimestamp,
boolean hidden,
boolean pniSignatureVerified,
@NonNull String nicknameGivenName,
@NonNull String nicknameFamilyName,
@NonNull String note)
{
return Arrays.equals(contact.serializeUnknownFields(), unknownFields) &&
Objects.equals(contact.getAci().orElse(null), aci) &&
Objects.equals(contact.getPni().orElse(null), pni) &&
Objects.equals(contact.getNumber().orElse(null), e164) &&
Objects.equals(contact.getProfileGivenName().orElse(""), profileGivenName) &&
Objects.equals(contact.getProfileFamilyName().orElse(""), profileFamilyName) &&
Objects.equals(contact.getSystemGivenName().orElse(""), systemGivenName) &&
Objects.equals(contact.getSystemFamilyName().orElse(""), systemFamilyName) &&
Objects.equals(contact.getSystemNickname().orElse(""), systemNickname) &&
Arrays.equals(contact.getProfileKey().orElse(null), profileKey) &&
Objects.equals(contact.getUsername().orElse(""), username) &&
Objects.equals(contact.getIdentityState(), identityState) &&
Arrays.equals(contact.getIdentityKey().orElse(null), identityKey) &&
contact.isBlocked() == blocked &&
contact.isProfileSharingEnabled() == profileSharing &&
contact.isArchived() == archived &&
contact.isForcedUnread() == forcedUnread &&
contact.getMuteUntil() == muteUntil &&
contact.shouldHideStory() == hideStory &&
contact.getUnregisteredTimestamp() == unregisteredTimestamp &&
contact.isHidden() == hidden &&
contact.isPniSignatureVerified() == pniSignatureVerified &&
Objects.equals(contact.getNicknameGivenName().orElse(""), nicknameGivenName) &&
Objects.equals(contact.getNicknameFamilyName().orElse(""), nicknameFamilyName) &&
Objects.equals(contact.getNote().orElse(""), note);
}
}

View File

@@ -0,0 +1,267 @@
package org.thoughtcrime.securesms.storage
import okio.ByteString
import okio.ByteString.Companion.toByteString
import org.signal.core.util.isEmpty
import org.signal.core.util.isNotEmpty
import org.signal.core.util.logging.Log
import org.signal.core.util.nullIfBlank
import org.signal.core.util.nullIfEmpty
import org.thoughtcrime.securesms.database.RecipientTable
import org.thoughtcrime.securesms.database.SignalDatabase
import org.thoughtcrime.securesms.database.model.RecipientRecord
import org.thoughtcrime.securesms.jobs.RetrieveProfileJob.Companion.enqueue
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.recipients.Recipient.Companion.trustedPush
import org.thoughtcrime.securesms.recipients.RecipientId
import org.thoughtcrime.securesms.storage.StorageSyncModels.localToRemoteRecord
import org.whispersystems.signalservice.api.push.ServiceId.ACI
import org.whispersystems.signalservice.api.push.ServiceId.PNI
import org.whispersystems.signalservice.api.storage.SignalContactRecord
import org.whispersystems.signalservice.api.storage.StorageId
import org.whispersystems.signalservice.api.storage.signalAci
import org.whispersystems.signalservice.api.storage.signalPni
import org.whispersystems.signalservice.api.storage.toSignalContactRecord
import org.whispersystems.signalservice.internal.storage.protos.ContactRecord.IdentityState
import java.io.IOException
import java.util.Optional
import java.util.regex.Pattern
/**
* Record processor for [SignalContactRecord].
* Handles merging and updating our local store when processing remote contact storage records.
*/
class ContactRecordProcessor(
private val selfAci: ACI?,
private val selfPni: PNI?,
private val selfE164: String?,
private val recipientTable: RecipientTable
) : DefaultStorageRecordProcessor<SignalContactRecord>() {
companion object {
private val TAG = Log.tag(ContactRecordProcessor::class.java)
private val E164_PATTERN: Pattern = Pattern.compile("^\\+[1-9]\\d{0,18}$")
private fun isValidE164(value: String): Boolean {
return E164_PATTERN.matcher(value).matches()
}
}
constructor() : this(
selfAci = SignalStore.account.aci,
selfPni = SignalStore.account.pni,
selfE164 = SignalStore.account.e164,
recipientTable = SignalDatabase.recipients
)
/**
* For contact records specifically, we have some extra work that needs to be done before we process all of the records.
*
* We have to find all unregistered ACI-only records and split them into two separate contact rows locally, if necessary.
* The reasons are nuanced, but the TL;DR is that we want to split unregistered users into separate rows so that a user
* could re-register and get a different ACI.
*/
@Throws(IOException::class)
override fun process(remoteRecords: Collection<SignalContactRecord>, keyGenerator: StorageKeyGenerator) {
val unregisteredAciOnly: MutableList<SignalContactRecord> = ArrayList()
for (remoteRecord in remoteRecords) {
if (isInvalid(remoteRecord)) {
continue
}
if (remoteRecord.proto.unregisteredAtTimestamp > 0 && remoteRecord.proto.signalAci != null && remoteRecord.proto.signalPni == null && remoteRecord.proto.e164.isBlank()) {
unregisteredAciOnly.add(remoteRecord)
}
}
if (unregisteredAciOnly.size > 0) {
for (aciOnly in unregisteredAciOnly) {
SignalDatabase.recipients.splitForStorageSyncIfNecessary(aciOnly.proto.signalAci!!)
}
}
super.process(remoteRecords, keyGenerator)
}
/**
* Error cases:
* - You can't have a contact record without an ACI or PNI.
* - You can't have a contact record for yourself. That should be an account record.
*
* Note: This method could be written more succinctly, but the logs are useful :)
*/
override fun isInvalid(remote: SignalContactRecord): Boolean {
val hasAci = remote.proto.signalAci?.isValid == true
val hasPni = remote.proto.signalPni?.isValid == true
if (!hasAci && !hasPni) {
Log.w(TAG, "Found a ContactRecord with neither an ACI nor a PNI -- marking as invalid.")
return true
} else if (selfAci != null && selfAci == remote.proto.signalAci ||
(selfPni != null && selfPni == remote.proto.signalPni) ||
(selfE164 != null && remote.proto.e164.isNotBlank() && remote.proto.e164 == selfE164)
) {
Log.w(TAG, "Found a ContactRecord for ourselves -- marking as invalid.")
return true
} else if (remote.proto.e164.isNotBlank() && !isValidE164(remote.proto.e164)) {
Log.w(TAG, "Found a record with an invalid E164. Marking as invalid.")
return true
} else {
return false
}
}
override fun getMatching(remote: SignalContactRecord, keyGenerator: StorageKeyGenerator): Optional<SignalContactRecord> {
var found: Optional<RecipientId> = remote.proto.signalAci?.let { recipientTable.getByAci(it) } ?: Optional.empty()
if (found.isEmpty && remote.proto.e164.isNotBlank()) {
found = recipientTable.getByE164(remote.proto.e164)
}
if (found.isEmpty && remote.proto.signalPni != null) {
found = recipientTable.getByPni(remote.proto.signalPni!!)
}
return found
.map { recipientTable.getRecordForSync(it)!! }
.map { settings: RecipientRecord ->
if (settings.storageId != null) {
return@map localToRemoteRecord(settings)
} else {
Log.w(TAG, "Newly discovering a registered user via storage service. Saving a storageId for them.")
recipientTable.updateStorageId(settings.id, keyGenerator.generate())
val updatedSettings = recipientTable.getRecordForSync(settings.id)!!
return@map localToRemoteRecord(updatedSettings)
}
}
.map { record -> SignalContactRecord(record.id, record.proto.contact!!) }
}
override fun merge(remote: SignalContactRecord, local: SignalContactRecord, keyGenerator: StorageKeyGenerator): SignalContactRecord {
val mergedProfileGivenName: String
val mergedProfileFamilyName: String
val localAci = local.proto.signalAci
val localPni = local.proto.signalPni
val remoteAci = remote.proto.signalAci
val remotePni = remote.proto.signalPni
if (remote.proto.givenName.isNotBlank() || remote.proto.familyName.isNotBlank()) {
mergedProfileGivenName = remote.proto.givenName
mergedProfileFamilyName = remote.proto.familyName
} else {
mergedProfileGivenName = local.proto.givenName
mergedProfileFamilyName = local.proto.familyName
}
val mergedIdentityState: IdentityState
val mergedIdentityKey: ByteArray?
if ((remote.proto.identityState != local.proto.identityState && remote.proto.identityKey.isNotEmpty()) ||
(remote.proto.identityKey.isNotEmpty() && local.proto.identityKey.isEmpty()) ||
(remote.proto.identityKey.isNotEmpty() && local.proto.unregisteredAtTimestamp > 0)
) {
mergedIdentityState = remote.proto.identityState
mergedIdentityKey = remote.proto.identityKey.takeIf { it.isNotEmpty() }?.toByteArray()
} else {
mergedIdentityState = local.proto.identityState
mergedIdentityKey = local.proto.identityKey.takeIf { it.isNotEmpty() }?.toByteArray()
}
if (localAci != null && mergedIdentityKey != null && remote.proto.identityKey.isNotEmpty() && !mergedIdentityKey.contentEquals(remote.proto.identityKey.toByteArray())) {
Log.w(TAG, "The local and remote identity keys do not match for " + localAci + ". Enqueueing a profile fetch.")
enqueue(trustedPush(localAci, localPni, local.proto.e164).id)
}
val mergedPni: PNI?
val mergedE164: String?
val e164sMatchButPnisDont = local.proto.e164.isNotBlank() &&
local.proto.e164 == remote.proto.e164 &&
localPni != null &&
remotePni != null &&
localPni != remotePni
val pnisMatchButE164sDont = localPni != null &&
localPni == remotePni &&
local.proto.e164.isNotBlank() &&
remote.proto.e164.isNotBlank() &&
local.proto.e164 != remote.proto.e164
if (e164sMatchButPnisDont) {
Log.w(TAG, "Matching E164s, but the PNIs differ! Trusting our local pair.")
// TODO [pnp] Schedule CDS fetch?
mergedPni = localPni
mergedE164 = local.proto.e164
} else if (pnisMatchButE164sDont) {
Log.w(TAG, "Matching PNIs, but the E164s differ! Trusting our local pair.")
// TODO [pnp] Schedule CDS fetch?
mergedPni = localPni
mergedE164 = local.proto.e164
} else {
mergedPni = remotePni ?: localPni
mergedE164 = remote.proto.e164.nullIfBlank() ?: local.proto.e164.nullIfBlank()
}
val merged = SignalContactRecord.newBuilder(remote.serializedUnknowns).apply {
e164 = mergedE164 ?: ""
aci = local.proto.aci.nullIfBlank() ?: remote.proto.aci
pni = mergedPni?.toStringWithoutPrefix() ?: ""
givenName = mergedProfileGivenName
familyName = mergedProfileFamilyName
profileKey = remote.proto.profileKey.nullIfEmpty() ?: local.proto.profileKey
username = remote.proto.username.nullIfBlank() ?: local.proto.username
identityState = mergedIdentityState
identityKey = mergedIdentityKey?.toByteString() ?: ByteString.EMPTY
blocked = remote.proto.blocked
whitelisted = remote.proto.whitelisted
archived = remote.proto.archived
markedUnread = remote.proto.markedUnread
mutedUntilTimestamp = remote.proto.mutedUntilTimestamp
hideStory = remote.proto.hideStory
unregisteredAtTimestamp = remote.proto.unregisteredAtTimestamp
hidden = remote.proto.hidden
systemGivenName = if (SignalStore.account.isPrimaryDevice) local.proto.systemGivenName else remote.proto.systemGivenName
systemFamilyName = if (SignalStore.account.isPrimaryDevice) local.proto.systemFamilyName else remote.proto.systemFamilyName
systemNickname = remote.proto.systemNickname
nickname = remote.proto.nickname
pniSignatureVerified = remote.proto.pniSignatureVerified || local.proto.pniSignatureVerified
note = remote.proto.note.nullIfBlank() ?: local.proto.note
}.build().toSignalContactRecord(StorageId.forContact(keyGenerator.generate()))
val matchesRemote = doParamsMatch(remote, merged)
val matchesLocal = doParamsMatch(local, merged)
return if (matchesRemote) {
remote
} else if (matchesLocal) {
local
} else {
merged
}
}
override fun insertLocal(record: SignalContactRecord) {
recipientTable.applyStorageSyncContactInsert(record)
}
override fun updateLocal(update: StorageRecordUpdate<SignalContactRecord>) {
recipientTable.applyStorageSyncContactUpdate(update)
}
override fun compare(lhs: SignalContactRecord, rhs: SignalContactRecord): Int {
return if (
(lhs.proto.signalAci != null && lhs.proto.aci == rhs.proto.aci) ||
(lhs.proto.e164.isNotBlank() && lhs.proto.e164 == rhs.proto.e164) ||
(lhs.proto.signalPni != null && lhs.proto.pni == rhs.proto.pni)
) {
0
} else {
1
}
}
}

View File

@@ -70,6 +70,10 @@ abstract class DefaultStorageRecordProcessor<E : SignalRecord<*>> : StorageRecor
} }
} }
fun doParamsMatch(base: E, test: E): Boolean {
return base.serializedUnknowns.contentEquals(test.serializedUnknowns) && base.proto == test.proto
}
private fun info(i: Int, record: E, message: String) { private fun info(i: Int, record: E, message: String) {
Log.i(TAG, "[$i][${record.javaClass.getSimpleName()}] $message") Log.i(TAG, "[$i][${record.javaClass.getSimpleName()}] $message")
} }

View File

@@ -9,11 +9,13 @@ import org.thoughtcrime.securesms.groups.BadGroupIdException
import org.thoughtcrime.securesms.groups.GroupId import org.thoughtcrime.securesms.groups.GroupId
import org.whispersystems.signalservice.api.storage.SignalGroupV1Record import org.whispersystems.signalservice.api.storage.SignalGroupV1Record
import org.whispersystems.signalservice.api.storage.SignalStorageRecord import org.whispersystems.signalservice.api.storage.SignalStorageRecord
import org.whispersystems.signalservice.api.storage.StorageId
import org.whispersystems.signalservice.api.storage.toSignalGroupV1Record import org.whispersystems.signalservice.api.storage.toSignalGroupV1Record
import java.util.Optional import java.util.Optional
/** /**
* Handles merging remote storage updates into local group v1 state. * Record processor for [SignalGroupV1Record].
* Handles merging and updating our local store when processing remote gv1 storage records.
*/ */
class GroupV1RecordProcessor(private val groupDatabase: GroupTable, private val recipientTable: RecipientTable) : DefaultStorageRecordProcessor<SignalGroupV1Record>() { class GroupV1RecordProcessor(private val groupDatabase: GroupTable, private val recipientTable: RecipientTable) : DefaultStorageRecordProcessor<SignalGroupV1Record>() {
companion object { companion object {
@@ -31,7 +33,7 @@ class GroupV1RecordProcessor(private val groupDatabase: GroupTable, private val
*/ */
override fun isInvalid(remote: SignalGroupV1Record): Boolean { override fun isInvalid(remote: SignalGroupV1Record): Boolean {
try { try {
val id = GroupId.v1(remote.groupId) val id = GroupId.v1(remote.proto.id.toByteArray())
val v2Record = groupDatabase.getGroup(id.deriveV2MigrationGroupId()) val v2Record = groupDatabase.getGroup(id.deriveV2MigrationGroupId())
if (v2Record.isPresent) { if (v2Record.isPresent) {
@@ -47,7 +49,7 @@ class GroupV1RecordProcessor(private val groupDatabase: GroupTable, private val
} }
override fun getMatching(remote: SignalGroupV1Record, keyGenerator: StorageKeyGenerator): Optional<SignalGroupV1Record> { override fun getMatching(remote: SignalGroupV1Record, keyGenerator: StorageKeyGenerator): Optional<SignalGroupV1Record> {
val groupId = GroupId.v1orThrow(remote.groupId) val groupId = GroupId.v1orThrow(remote.proto.id.toByteArray())
val recipientId = recipientTable.getByGroupId(groupId) val recipientId = recipientTable.getByGroupId(groupId)
@@ -58,28 +60,24 @@ class GroupV1RecordProcessor(private val groupDatabase: GroupTable, private val
} }
override fun merge(remote: SignalGroupV1Record, local: SignalGroupV1Record, keyGenerator: StorageKeyGenerator): SignalGroupV1Record { override fun merge(remote: SignalGroupV1Record, local: SignalGroupV1Record, keyGenerator: StorageKeyGenerator): SignalGroupV1Record {
val unknownFields = remote.serializeUnknownFields() val merged = SignalGroupV1Record.newBuilder(remote.serializedUnknowns).apply {
val blocked = remote.isBlocked id = remote.proto.id
val profileSharing = remote.isProfileSharingEnabled blocked = remote.proto.blocked
val archived = remote.isArchived whitelisted = remote.proto.whitelisted
val forcedUnread = remote.isForcedUnread archived = remote.proto.archived
val muteUntil = remote.muteUntil markedUnread = remote.proto.markedUnread
mutedUntilTimestamp = remote.proto.mutedUntilTimestamp
}.build().toSignalGroupV1Record(StorageId.forGroupV1(keyGenerator.generate()))
val matchesRemote = doParamsMatch(group = remote, unknownFields = unknownFields, blocked = blocked, profileSharing = profileSharing, archived = archived, forcedUnread = forcedUnread, muteUntil = muteUntil) val matchesRemote = doParamsMatch(remote, merged)
val matchesLocal = doParamsMatch(group = local, unknownFields = unknownFields, blocked = blocked, profileSharing = profileSharing, archived = archived, forcedUnread = forcedUnread, muteUntil = muteUntil) val matchesLocal = doParamsMatch(local, merged)
return if (matchesRemote) { return if (matchesRemote) {
remote remote
} else if (matchesLocal) { } else if (matchesLocal) {
local local
} else { } else {
SignalGroupV1Record.Builder(keyGenerator.generate(), remote.groupId, unknownFields) merged
.setBlocked(blocked)
.setProfileSharingEnabled(profileSharing)
.setArchived(archived)
.setForcedUnread(forcedUnread)
.setMuteUntil(muteUntil)
.build()
} }
} }
@@ -92,27 +90,10 @@ class GroupV1RecordProcessor(private val groupDatabase: GroupTable, private val
} }
override fun compare(lhs: SignalGroupV1Record, rhs: SignalGroupV1Record): Int { override fun compare(lhs: SignalGroupV1Record, rhs: SignalGroupV1Record): Int {
return if (lhs.groupId.contentEquals(rhs.groupId)) { return if (lhs.proto.id == rhs.proto.id) {
0 0
} else { } else {
1 1
} }
} }
private fun doParamsMatch(
group: SignalGroupV1Record,
unknownFields: ByteArray?,
blocked: Boolean,
profileSharing: Boolean,
archived: Boolean,
forcedUnread: Boolean,
muteUntil: Long
): Boolean {
return unknownFields.contentEquals(group.serializeUnknownFields()) &&
blocked == group.isBlocked &&
profileSharing == group.isProfileSharingEnabled &&
archived == group.isArchived &&
forcedUnread == group.isForcedUnread &&
muteUntil == group.muteUntil
}
} }

View File

@@ -9,10 +9,14 @@ import org.thoughtcrime.securesms.database.model.RecipientRecord
import org.thoughtcrime.securesms.groups.GroupId import org.thoughtcrime.securesms.groups.GroupId
import org.whispersystems.signalservice.api.storage.SignalGroupV2Record import org.whispersystems.signalservice.api.storage.SignalGroupV2Record
import org.whispersystems.signalservice.api.storage.SignalStorageRecord import org.whispersystems.signalservice.api.storage.SignalStorageRecord
import org.whispersystems.signalservice.api.storage.StorageId
import org.whispersystems.signalservice.api.storage.toSignalGroupV2Record import org.whispersystems.signalservice.api.storage.toSignalGroupV2Record
import org.whispersystems.signalservice.internal.storage.protos.GroupV2Record
import java.util.Optional import java.util.Optional
/**
* Record processor for [SignalGroupV2Record].
* Handles merging and updating our local store when processing remote gv2 storage records.
*/
class GroupV2RecordProcessor(private val recipientTable: RecipientTable, private val groupDatabase: GroupTable) : DefaultStorageRecordProcessor<SignalGroupV2Record>() { class GroupV2RecordProcessor(private val recipientTable: RecipientTable, private val groupDatabase: GroupTable) : DefaultStorageRecordProcessor<SignalGroupV2Record>() {
companion object { companion object {
private val TAG = Log.tag(GroupV2RecordProcessor::class.java) private val TAG = Log.tag(GroupV2RecordProcessor::class.java)
@@ -21,11 +25,11 @@ class GroupV2RecordProcessor(private val recipientTable: RecipientTable, private
constructor() : this(SignalDatabase.recipients, SignalDatabase.groups) constructor() : this(SignalDatabase.recipients, SignalDatabase.groups)
override fun isInvalid(remote: SignalGroupV2Record): Boolean { override fun isInvalid(remote: SignalGroupV2Record): Boolean {
return remote.masterKeyBytes.size != GroupMasterKey.SIZE return remote.proto.masterKey.size != GroupMasterKey.SIZE
} }
override fun getMatching(remote: SignalGroupV2Record, keyGenerator: StorageKeyGenerator): Optional<SignalGroupV2Record> { override fun getMatching(remote: SignalGroupV2Record, keyGenerator: StorageKeyGenerator): Optional<SignalGroupV2Record> {
val groupId = GroupId.v2(remote.masterKeyOrThrow) val groupId = GroupId.v2(GroupMasterKey(remote.proto.masterKey.toByteArray()))
val recipientId = recipientTable.getByGroupId(groupId) val recipientId = recipientTable.getByGroupId(groupId)
@@ -36,64 +40,35 @@ class GroupV2RecordProcessor(private val recipientTable: RecipientTable, private
StorageSyncModels.localToRemoteRecord(settings) StorageSyncModels.localToRemoteRecord(settings)
} else { } else {
Log.w(TAG, "No local master key. Assuming it matches remote since the groupIds match. Enqueuing a fetch to fix the bad state.") Log.w(TAG, "No local master key. Assuming it matches remote since the groupIds match. Enqueuing a fetch to fix the bad state.")
groupDatabase.fixMissingMasterKey(remote.masterKeyOrThrow) groupDatabase.fixMissingMasterKey(GroupMasterKey(remote.proto.masterKey.toByteArray()))
StorageSyncModels.localToRemoteRecord(settings, remote.masterKeyOrThrow) StorageSyncModels.localToRemoteRecord(settings, GroupMasterKey(remote.proto.masterKey.toByteArray()))
} }
} }
.map { record: SignalStorageRecord -> record.proto.groupV2!!.toSignalGroupV2Record(record.id) } .map { record: SignalStorageRecord -> record.proto.groupV2!!.toSignalGroupV2Record(record.id) }
} }
override fun merge(remote: SignalGroupV2Record, local: SignalGroupV2Record, keyGenerator: StorageKeyGenerator): SignalGroupV2Record { override fun merge(remote: SignalGroupV2Record, local: SignalGroupV2Record, keyGenerator: StorageKeyGenerator): SignalGroupV2Record {
val unknownFields = remote.serializeUnknownFields() val merged = SignalGroupV2Record.newBuilder(remote.serializedUnknowns).apply {
val blocked = remote.isBlocked masterKey = remote.proto.masterKey
val profileSharing = remote.isProfileSharingEnabled blocked = remote.proto.blocked
val archived = remote.isArchived whitelisted = remote.proto.whitelisted
val forcedUnread = remote.isForcedUnread archived = remote.proto.archived
val muteUntil = remote.muteUntil markedUnread = remote.proto.markedUnread
val notifyForMentionsWhenMuted = remote.notifyForMentionsWhenMuted() mutedUntilTimestamp = remote.proto.mutedUntilTimestamp
val hideStory = remote.shouldHideStory() dontNotifyForMentionsIfMuted = remote.proto.dontNotifyForMentionsIfMuted
val storySendMode = remote.storySendMode hideStory = remote.proto.hideStory
storySendMode = remote.proto.storySendMode
}.build().toSignalGroupV2Record(StorageId.forGroupV2(keyGenerator.generate()))
val matchesRemote = doParamsMatch( val matchesRemote = doParamsMatch(remote, merged)
group = remote, val matchesLocal = doParamsMatch(local, merged)
unknownFields = unknownFields,
blocked = blocked,
profileSharing = profileSharing,
archived = archived,
forcedUnread = forcedUnread,
muteUntil = muteUntil,
notifyForMentionsWhenMuted = notifyForMentionsWhenMuted,
hideStory = hideStory,
storySendMode = storySendMode
)
val matchesLocal = doParamsMatch(
group = local,
unknownFields = unknownFields,
blocked = blocked,
profileSharing = profileSharing,
archived = archived,
forcedUnread = forcedUnread,
muteUntil = muteUntil,
notifyForMentionsWhenMuted = notifyForMentionsWhenMuted,
hideStory = hideStory,
storySendMode = storySendMode
)
return if (matchesRemote) { return if (matchesRemote) {
remote remote
} else if (matchesLocal) { } else if (matchesLocal) {
local local
} else { } else {
SignalGroupV2Record.Builder(keyGenerator.generate(), remote.masterKeyBytes, unknownFields) merged
.setBlocked(blocked)
.setProfileSharingEnabled(profileSharing)
.setArchived(archived)
.setForcedUnread(forcedUnread)
.setMuteUntil(muteUntil)
.setNotifyForMentionsWhenMuted(notifyForMentionsWhenMuted)
.setHideStory(hideStory)
.setStorySendMode(storySendMode)
.build()
} }
} }
@@ -106,33 +81,10 @@ class GroupV2RecordProcessor(private val recipientTable: RecipientTable, private
} }
override fun compare(lhs: SignalGroupV2Record, rhs: SignalGroupV2Record): Int { override fun compare(lhs: SignalGroupV2Record, rhs: SignalGroupV2Record): Int {
return if (lhs.masterKeyBytes.contentEquals(rhs.masterKeyBytes)) { return if (lhs.proto.masterKey == rhs.proto.masterKey) {
0 0
} else { } else {
1 1
} }
} }
private fun doParamsMatch(
group: SignalGroupV2Record,
unknownFields: ByteArray?,
blocked: Boolean,
profileSharing: Boolean,
archived: Boolean,
forcedUnread: Boolean,
muteUntil: Long,
notifyForMentionsWhenMuted: Boolean,
hideStory: Boolean,
storySendMode: GroupV2Record.StorySendMode
): Boolean {
return unknownFields.contentEquals(group.serializeUnknownFields()) &&
blocked == group.isBlocked &&
profileSharing == group.isProfileSharingEnabled &&
archived == group.isArchived &&
forcedUnread == group.isForcedUnread &&
muteUntil == group.muteUntil &&
notifyForMentionsWhenMuted == group.notifyForMentionsWhenMuted() &&
hideStory == group.shouldHideStory() &&
storySendMode == group.storySendMode
}
} }

View File

@@ -34,7 +34,6 @@ import org.whispersystems.signalservice.api.storage.safeSetPayments
import org.whispersystems.signalservice.api.storage.safeSetSubscriber import org.whispersystems.signalservice.api.storage.safeSetSubscriber
import org.whispersystems.signalservice.api.storage.toSignalAccountRecord import org.whispersystems.signalservice.api.storage.toSignalAccountRecord
import org.whispersystems.signalservice.api.storage.toSignalStorageRecord import org.whispersystems.signalservice.api.storage.toSignalStorageRecord
import org.whispersystems.signalservice.api.util.OptionalUtil.byteArrayEquals
import org.whispersystems.signalservice.api.util.UuidUtil import org.whispersystems.signalservice.api.util.UuidUtil
import org.whispersystems.signalservice.api.util.toByteArray import org.whispersystems.signalservice.api.util.toByteArray
import org.whispersystems.signalservice.internal.storage.protos.AccountRecord import org.whispersystems.signalservice.internal.storage.protos.AccountRecord
@@ -105,7 +104,7 @@ object StorageSyncHelper {
@JvmStatic @JvmStatic
fun profileKeyChanged(update: StorageRecordUpdate<SignalContactRecord>): Boolean { fun profileKeyChanged(update: StorageRecordUpdate<SignalContactRecord>): Boolean {
return !byteArrayEquals(update.old.profileKey, update.new.profileKey) return update.old.proto.profileKey != update.new.proto.profileKey
} }
@JvmStatic @JvmStatic

View File

@@ -18,17 +18,23 @@ import org.thoughtcrime.securesms.database.model.RecipientRecord
import org.thoughtcrime.securesms.database.model.databaseprotos.InAppPaymentData import org.thoughtcrime.securesms.database.model.databaseprotos.InAppPaymentData
import org.thoughtcrime.securesms.keyvalue.PhoneNumberPrivacyValues import org.thoughtcrime.securesms.keyvalue.PhoneNumberPrivacyValues
import org.thoughtcrime.securesms.recipients.Recipient import org.thoughtcrime.securesms.recipients.Recipient
import org.whispersystems.signalservice.api.push.SignalServiceAddress
import org.whispersystems.signalservice.api.storage.SignalCallLinkRecord import org.whispersystems.signalservice.api.storage.SignalCallLinkRecord
import org.whispersystems.signalservice.api.storage.SignalContactRecord import org.whispersystems.signalservice.api.storage.SignalContactRecord
import org.whispersystems.signalservice.api.storage.SignalGroupV1Record import org.whispersystems.signalservice.api.storage.SignalGroupV1Record
import org.whispersystems.signalservice.api.storage.SignalGroupV2Record import org.whispersystems.signalservice.api.storage.SignalGroupV2Record
import org.whispersystems.signalservice.api.storage.SignalStorageRecord import org.whispersystems.signalservice.api.storage.SignalStorageRecord
import org.whispersystems.signalservice.api.storage.SignalStoryDistributionListRecord import org.whispersystems.signalservice.api.storage.SignalStoryDistributionListRecord
import org.whispersystems.signalservice.api.storage.StorageId
import org.whispersystems.signalservice.api.storage.toSignalCallLinkRecord
import org.whispersystems.signalservice.api.storage.toSignalContactRecord
import org.whispersystems.signalservice.api.storage.toSignalGroupV1Record
import org.whispersystems.signalservice.api.storage.toSignalGroupV2Record
import org.whispersystems.signalservice.api.storage.toSignalStorageRecord import org.whispersystems.signalservice.api.storage.toSignalStorageRecord
import org.whispersystems.signalservice.api.storage.toSignalStoryDistributionListRecord
import org.whispersystems.signalservice.api.subscriptions.SubscriberId import org.whispersystems.signalservice.api.subscriptions.SubscriberId
import org.whispersystems.signalservice.api.util.UuidUtil import org.whispersystems.signalservice.api.util.UuidUtil
import org.whispersystems.signalservice.internal.storage.protos.AccountRecord import org.whispersystems.signalservice.internal.storage.protos.AccountRecord
import org.whispersystems.signalservice.internal.storage.protos.ContactRecord
import org.whispersystems.signalservice.internal.storage.protos.ContactRecord.IdentityState import org.whispersystems.signalservice.internal.storage.protos.ContactRecord.IdentityState
import org.whispersystems.signalservice.internal.storage.protos.GroupV2Record import org.whispersystems.signalservice.internal.storage.protos.GroupV2Record
import java.util.Currency import java.util.Currency
@@ -150,33 +156,31 @@ object StorageSyncModels {
throw AssertionError("Must have either a UUID or a phone number!") throw AssertionError("Must have either a UUID or a phone number!")
} }
val hideStory = recipient.extras != null && recipient.extras.hideStory() return SignalContactRecord.newBuilder(recipient.syncExtras.storageProto).apply {
aci = recipient.aci?.toString() ?: ""
return SignalContactRecord.Builder(rawStorageId, recipient.aci, recipient.syncExtras.storageProto) e164 = recipient.e164 ?: ""
.setE164(recipient.e164) pni = recipient.pni?.toStringWithoutPrefix() ?: ""
.setPni(recipient.pni) profileKey = recipient.profileKey?.toByteString() ?: ByteString.EMPTY
.setProfileKey(recipient.profileKey) givenName = recipient.signalProfileName.givenName
.setProfileGivenName(recipient.signalProfileName.givenName) familyName = recipient.signalProfileName.familyName
.setProfileFamilyName(recipient.signalProfileName.familyName) systemGivenName = recipient.systemProfileName.givenName
.setSystemGivenName(recipient.systemProfileName.givenName) systemFamilyName = recipient.systemProfileName.familyName
.setSystemFamilyName(recipient.systemProfileName.familyName) systemNickname = recipient.syncExtras.systemNickname ?: ""
.setSystemNickname(recipient.syncExtras.systemNickname) blocked = recipient.isBlocked
.setBlocked(recipient.isBlocked) whitelisted = recipient.profileSharing || recipient.systemContactUri != null
.setProfileSharingEnabled(recipient.profileSharing || recipient.systemContactUri != null) identityKey = recipient.syncExtras.identityKey?.toByteString() ?: ByteString.EMPTY
.setIdentityKey(recipient.syncExtras.identityKey) identityState = localToRemoteIdentityState(recipient.syncExtras.identityStatus)
.setIdentityState(localToRemoteIdentityState(recipient.syncExtras.identityStatus)) archived = recipient.syncExtras.isArchived
.setArchived(recipient.syncExtras.isArchived) markedUnread = recipient.syncExtras.isForcedUnread
.setForcedUnread(recipient.syncExtras.isForcedUnread) mutedUntilTimestamp = recipient.muteUntil
.setMuteUntil(recipient.muteUntil) hideStory = recipient.extras != null && recipient.extras.hideStory()
.setHideStory(hideStory) unregisteredAtTimestamp = recipient.syncExtras.unregisteredTimestamp
.setUnregisteredTimestamp(recipient.syncExtras.unregisteredTimestamp) hidden = recipient.hiddenState != Recipient.HiddenState.NOT_HIDDEN
.setHidden(recipient.hiddenState != Recipient.HiddenState.NOT_HIDDEN) username = recipient.username ?: ""
.setUsername(recipient.username) pniSignatureVerified = recipient.syncExtras.pniSignatureVerified
.setPniSignatureVerified(recipient.syncExtras.pniSignatureVerified) nickname = recipient.nickname.takeUnless { it.isEmpty }?.let { ContactRecord.Name(given = it.givenName, family = it.familyName) }
.setNicknameGivenName(recipient.nickname.givenName) note = recipient.note ?: ""
.setNicknameFamilyName(recipient.nickname.familyName) }.build().toSignalContactRecord(StorageId.forContact(rawStorageId))
.setNote(recipient.note)
.build()
} }
private fun localToRemoteGroupV1(recipient: RecipientRecord, rawStorageId: ByteArray): SignalGroupV1Record { private fun localToRemoteGroupV1(recipient: RecipientRecord, rawStorageId: ByteArray): SignalGroupV1Record {
@@ -186,13 +190,14 @@ object StorageSyncModels {
throw AssertionError("Group is not V1") throw AssertionError("Group is not V1")
} }
return SignalGroupV1Record.Builder(rawStorageId, groupId.decodedId, recipient.syncExtras.storageProto) return SignalGroupV1Record.newBuilder(recipient.syncExtras.storageProto).apply {
.setBlocked(recipient.isBlocked) id = recipient.groupId.requireV1().decodedId.toByteString()
.setProfileSharingEnabled(recipient.profileSharing) blocked = recipient.isBlocked
.setArchived(recipient.syncExtras.isArchived) whitelisted = recipient.profileSharing
.setForcedUnread(recipient.syncExtras.isForcedUnread) archived = recipient.syncExtras.isArchived
.setMuteUntil(recipient.muteUntil) markedUnread = recipient.syncExtras.isForcedUnread
.build() mutedUntilTimestamp = recipient.muteUntil
}.build().toSignalGroupV1Record(StorageId.forGroupV1(rawStorageId))
} }
private fun localToRemoteGroupV2(recipient: RecipientRecord, rawStorageId: ByteArray?, groupMasterKey: GroupMasterKey): SignalGroupV2Record { private fun localToRemoteGroupV2(recipient: RecipientRecord, rawStorageId: ByteArray?, groupMasterKey: GroupMasterKey): SignalGroupV2Record {
@@ -202,29 +207,21 @@ object StorageSyncModels {
throw AssertionError("Group is not V2") throw AssertionError("Group is not V2")
} }
if (groupMasterKey == null) { return SignalGroupV2Record.newBuilder(recipient.syncExtras.storageProto).apply {
throw AssertionError("Group master key not on recipient record") masterKey = groupMasterKey.serialize().toByteString()
} blocked = recipient.isBlocked
whitelisted = recipient.profileSharing
val hideStory = recipient.extras != null && recipient.extras.hideStory() archived = recipient.syncExtras.isArchived
val showAsStoryState = groups.getShowAsStoryState(groupId) markedUnread = recipient.syncExtras.isForcedUnread
mutedUntilTimestamp = recipient.muteUntil
val storySendMode = when (showAsStoryState) { dontNotifyForMentionsIfMuted = recipient.mentionSetting == RecipientTable.MentionSetting.ALWAYS_NOTIFY
hideStory = recipient.extras != null && recipient.extras.hideStory()
storySendMode = when (groups.getShowAsStoryState(groupId)) {
ShowAsStoryState.ALWAYS -> GroupV2Record.StorySendMode.ENABLED ShowAsStoryState.ALWAYS -> GroupV2Record.StorySendMode.ENABLED
ShowAsStoryState.NEVER -> GroupV2Record.StorySendMode.DISABLED ShowAsStoryState.NEVER -> GroupV2Record.StorySendMode.DISABLED
else -> GroupV2Record.StorySendMode.DEFAULT else -> GroupV2Record.StorySendMode.DEFAULT
} }
}.build().toSignalGroupV2Record(StorageId.forGroupV2(rawStorageId))
return SignalGroupV2Record.Builder(rawStorageId, groupMasterKey, recipient.syncExtras.storageProto)
.setBlocked(recipient.isBlocked)
.setProfileSharingEnabled(recipient.profileSharing)
.setArchived(recipient.syncExtras.isArchived)
.setForcedUnread(recipient.syncExtras.isForcedUnread)
.setMuteUntil(recipient.muteUntil)
.setNotifyForMentionsWhenMuted(recipient.mentionSetting == RecipientTable.MentionSetting.ALWAYS_NOTIFY)
.setHideStory(hideStory)
.setStorySendMode(storySendMode)
.build()
} }
private fun localToRemoteCallLink(recipient: RecipientRecord, rawStorageId: ByteArray): SignalCallLinkRecord { private fun localToRemoteCallLink(recipient: RecipientRecord, rawStorageId: ByteArray): SignalCallLinkRecord {
@@ -239,11 +236,11 @@ object StorageSyncModels {
val deletedTimestamp = max(0.0, callLinks.getDeletedTimestampByRoomId(callLinkRoomId).toDouble()).toLong() val deletedTimestamp = max(0.0, callLinks.getDeletedTimestampByRoomId(callLinkRoomId).toDouble()).toLong()
val adminPassword = if (deletedTimestamp > 0) byteArrayOf() else callLink.credentials.adminPassBytes!! val adminPassword = if (deletedTimestamp > 0) byteArrayOf() else callLink.credentials.adminPassBytes!!
return SignalCallLinkRecord.Builder(rawStorageId, null) return SignalCallLinkRecord.newBuilder(null).apply {
.setRootKey(callLink.credentials.linkKeyBytes) rootKey = callLink.credentials.linkKeyBytes.toByteString()
.setAdminPassKey(adminPassword) adminPasskey = adminPassword.toByteString()
.setDeletedTimestamp(deletedTimestamp) deletedAtTimestampMs = deletedTimestamp
.build() }.build().toSignalCallLinkRecord(StorageId.forCallLink(rawStorageId))
} }
private fun localToRemoteStoryDistributionList(recipient: RecipientRecord, rawStorageId: ByteArray): SignalStoryDistributionListRecord { private fun localToRemoteStoryDistributionList(recipient: RecipientRecord, rawStorageId: ByteArray): SignalStoryDistributionListRecord {
@@ -252,25 +249,22 @@ object StorageSyncModels {
val record = distributionLists.getListForStorageSync(distributionListId) ?: throw AssertionError("Must have a distribution list record!") val record = distributionLists.getListForStorageSync(distributionListId) ?: throw AssertionError("Must have a distribution list record!")
if (record.deletedAtTimestamp > 0L) { if (record.deletedAtTimestamp > 0L) {
return SignalStoryDistributionListRecord.Builder(rawStorageId, recipient.syncExtras.storageProto) return SignalStoryDistributionListRecord.newBuilder(recipient.syncExtras.storageProto).apply {
.setIdentifier(UuidUtil.toByteArray(record.distributionId.asUuid())) identifier = UuidUtil.toByteArray(record.distributionId.asUuid()).toByteString()
.setDeletedAtTimestamp(record.deletedAtTimestamp) deletedAtTimestamp = record.deletedAtTimestamp
.build() }.build().toSignalStoryDistributionListRecord(StorageId.forStoryDistributionList(rawStorageId))
} }
return SignalStoryDistributionListRecord.Builder(rawStorageId, recipient.syncExtras.storageProto) return SignalStoryDistributionListRecord.newBuilder(recipient.syncExtras.storageProto).apply {
.setIdentifier(UuidUtil.toByteArray(record.distributionId.asUuid())) identifier = UuidUtil.toByteArray(record.distributionId.asUuid()).toByteString()
.setName(record.name) name = record.name
.setRecipients( recipientServiceIds = record.getMembersToSync()
record.getMembersToSync()
.map { Recipient.resolved(it) } .map { Recipient.resolved(it) }
.filter { it.hasServiceId } .filter { it.hasServiceId }
.map { it.requireServiceId() } .map { it.requireServiceId().toString() }
.map { SignalServiceAddress(it) } allowsReplies = record.allowsReplies
) isBlockList = record.privacyMode.isBlockList
.setAllowsReplies(record.allowsReplies) }.build().toSignalStoryDistributionListRecord(StorageId.forStoryDistributionList(rawStorageId))
.setIsBlockList(record.privacyMode.isBlockList)
.build()
} }
fun remoteToLocalIdentityStatus(identityState: IdentityState): VerifiedStatus { fun remoteToLocalIdentityStatus(identityState: IdentityState): VerifiedStatus {

View File

@@ -5,14 +5,18 @@ import org.signal.core.util.logging.Log
import org.thoughtcrime.securesms.database.RecipientTable import org.thoughtcrime.securesms.database.RecipientTable
import org.thoughtcrime.securesms.database.SignalDatabase import org.thoughtcrime.securesms.database.SignalDatabase
import org.whispersystems.signalservice.api.push.DistributionId import org.whispersystems.signalservice.api.push.DistributionId
import org.whispersystems.signalservice.api.push.SignalServiceAddress
import org.whispersystems.signalservice.api.storage.SignalStoryDistributionListRecord import org.whispersystems.signalservice.api.storage.SignalStoryDistributionListRecord
import org.whispersystems.signalservice.api.storage.StorageId
import org.whispersystems.signalservice.api.storage.toSignalStoryDistributionListRecord import org.whispersystems.signalservice.api.storage.toSignalStoryDistributionListRecord
import org.whispersystems.signalservice.api.util.OptionalUtil.asOptional import org.whispersystems.signalservice.api.util.OptionalUtil.asOptional
import org.whispersystems.signalservice.api.util.UuidUtil import org.whispersystems.signalservice.api.util.UuidUtil
import java.io.IOException import java.io.IOException
import java.util.Optional import java.util.Optional
/**
* Record processor for [SignalStoryDistributionListRecord].
* Handles merging and updating our local store when processing remote dlist storage records.
*/
class StoryDistributionListRecordProcessor : DefaultStorageRecordProcessor<SignalStoryDistributionListRecord>() { class StoryDistributionListRecordProcessor : DefaultStorageRecordProcessor<SignalStoryDistributionListRecord>() {
companion object { companion object {
@@ -28,7 +32,7 @@ class StoryDistributionListRecordProcessor : DefaultStorageRecordProcessor<Signa
* - A non-visually-empty name field OR a deleted at timestamp * - A non-visually-empty name field OR a deleted at timestamp
*/ */
override fun isInvalid(remote: SignalStoryDistributionListRecord): Boolean { override fun isInvalid(remote: SignalStoryDistributionListRecord): Boolean {
val remoteUuid = UuidUtil.parseOrNull(remote.identifier) val remoteUuid = UuidUtil.parseOrNull(remote.proto.identifier)
if (remoteUuid == null) { if (remoteUuid == null) {
Log.d(TAG, "Bad distribution list identifier -- marking as invalid") Log.d(TAG, "Bad distribution list identifier -- marking as invalid")
return true return true
@@ -42,7 +46,7 @@ class StoryDistributionListRecordProcessor : DefaultStorageRecordProcessor<Signa
haveSeenMyStory = haveSeenMyStory or isMyStory haveSeenMyStory = haveSeenMyStory or isMyStory
if (remote.deletedAtTimestamp > 0L) { if (remote.proto.deletedAtTimestamp > 0L) {
if (isMyStory) { if (isMyStory) {
Log.w(TAG, "Refusing to delete My Story -- marking as invalid") Log.w(TAG, "Refusing to delete My Story -- marking as invalid")
return true return true
@@ -51,7 +55,7 @@ class StoryDistributionListRecordProcessor : DefaultStorageRecordProcessor<Signa
} }
} }
if (StringUtil.isVisuallyEmpty(remote.name)) { if (StringUtil.isVisuallyEmpty(remote.proto.name)) {
Log.d(TAG, "Bad distribution list name (visually empty) -- marking as invalid") Log.d(TAG, "Bad distribution list name (visually empty) -- marking as invalid")
return true return true
} }
@@ -62,7 +66,7 @@ class StoryDistributionListRecordProcessor : DefaultStorageRecordProcessor<Signa
override fun getMatching(remote: SignalStoryDistributionListRecord, keyGenerator: StorageKeyGenerator): Optional<SignalStoryDistributionListRecord> { override fun getMatching(remote: SignalStoryDistributionListRecord, keyGenerator: StorageKeyGenerator): Optional<SignalStoryDistributionListRecord> {
Log.d(TAG, "Attempting to get matching record...") Log.d(TAG, "Attempting to get matching record...")
val matching = SignalDatabase.distributionLists.getRecipientIdForSyncRecord(remote) val matching = SignalDatabase.distributionLists.getRecipientIdForSyncRecord(remote)
if (matching == null && UuidUtil.parseOrThrow(remote.identifier) == DistributionId.MY_STORY.asUuid()) { if (matching == null && UuidUtil.parseOrThrow(remote.proto.identifier) == DistributionId.MY_STORY.asUuid()) {
Log.e(TAG, "Cannot find matching database record for My Story.") Log.e(TAG, "Cannot find matching database record for My Story.")
throw MyStoryDoesNotExistException() throw MyStoryDoesNotExistException()
} }
@@ -88,48 +92,24 @@ class StoryDistributionListRecordProcessor : DefaultStorageRecordProcessor<Signa
} }
override fun merge(remote: SignalStoryDistributionListRecord, local: SignalStoryDistributionListRecord, keyGenerator: StorageKeyGenerator): SignalStoryDistributionListRecord { override fun merge(remote: SignalStoryDistributionListRecord, local: SignalStoryDistributionListRecord, keyGenerator: StorageKeyGenerator): SignalStoryDistributionListRecord {
val unknownFields = remote.serializeUnknownFields() val merged = SignalStoryDistributionListRecord.newBuilder(remote.serializedUnknowns).apply {
val identifier = remote.identifier identifier = remote.proto.identifier
val name = remote.name name = remote.proto.name
val recipients = remote.recipients recipientServiceIds = remote.proto.recipientServiceIds
val deletedAtTimestamp = remote.deletedAtTimestamp deletedAtTimestamp = remote.proto.deletedAtTimestamp
val allowsReplies = remote.allowsReplies() allowsReplies = remote.proto.allowsReplies
val isBlockList = remote.isBlockList isBlockList = remote.proto.isBlockList
}.build().toSignalStoryDistributionListRecord(StorageId.forStoryDistributionList(keyGenerator.generate()))
val matchesRemote = doParamsMatch( val matchesRemote = doParamsMatch(remote, merged)
record = remote, val matchesLocal = doParamsMatch(local, merged)
unknownFields = unknownFields,
identifier = identifier,
name = name,
recipients = recipients,
deletedAtTimestamp = deletedAtTimestamp,
allowsReplies = allowsReplies,
isBlockList = isBlockList
)
val matchesLocal = doParamsMatch(
record = local,
unknownFields = unknownFields,
identifier = identifier,
name = name,
recipients = recipients,
deletedAtTimestamp = deletedAtTimestamp,
allowsReplies = allowsReplies,
isBlockList = isBlockList
)
return if (matchesRemote) { return if (matchesRemote) {
remote remote
} else if (matchesLocal) { } else if (matchesLocal) {
local local
} else { } else {
SignalStoryDistributionListRecord.Builder(keyGenerator.generate(), unknownFields) merged
.setIdentifier(identifier)
.setName(name)
.setRecipients(recipients)
.setDeletedAtTimestamp(deletedAtTimestamp)
.setAllowsReplies(allowsReplies)
.setIsBlockList(isBlockList)
.build()
} }
} }
@@ -143,44 +123,19 @@ class StoryDistributionListRecordProcessor : DefaultStorageRecordProcessor<Signa
} }
override fun compare(o1: SignalStoryDistributionListRecord, o2: SignalStoryDistributionListRecord): Int { override fun compare(o1: SignalStoryDistributionListRecord, o2: SignalStoryDistributionListRecord): Int {
return if (o1.identifier.contentEquals(o2.identifier)) { return if (o1.proto.identifier == o2.proto.identifier) {
0 0
} else { } else {
1 1
} }
} }
private fun doParamsMatch(
record: SignalStoryDistributionListRecord,
unknownFields: ByteArray?,
identifier: ByteArray?,
name: String?,
recipients: List<SignalServiceAddress>,
deletedAtTimestamp: Long,
allowsReplies: Boolean,
isBlockList: Boolean
): Boolean {
return unknownFields.contentEquals(record.serializeUnknownFields()) &&
identifier.contentEquals(record.identifier) &&
name == record.name &&
recipients == record.recipients &&
deletedAtTimestamp == record.deletedAtTimestamp &&
allowsReplies == record.allowsReplies() &&
isBlockList == record.isBlockList
}
/** /**
* Thrown when the RecipientSettings object for a given distribution list is not the * Thrown when the RecipientSettings object for a given distribution list is not the
* correct group type (4). * correct group type (4).
*/ */
private class InvalidGroupTypeException : RuntimeException() private class InvalidGroupTypeException : RuntimeException()
/**
* Thrown when the distribution list object returned from the storage sync helper is
* absent, even though a RecipientSettings was found.
*/
private class UnexpectedEmptyOptionalException : RuntimeException()
/** /**
* Thrown when we try to ge the matching record for the "My Story" distribution ID but * Thrown when we try to ge the matching record for the "My Story" distribution ID but
* it isn't in the database. * it isn't in the database.

View File

@@ -307,9 +307,9 @@ class ContactRecordProcessorTest {
val result = subject.merge(remote, local, TestKeyGenerator(STORAGE_ID_C)) val result = subject.merge(remote, local, TestKeyGenerator(STORAGE_ID_C))
// THEN // THEN
assertEquals(local.aci, result.aci) assertEquals(local.proto.aci, result.proto.aci)
assertEquals(local.number.get(), result.number.get()) assertEquals(local.proto.e164, result.proto.e164)
assertEquals(local.pni.get(), result.pni.get()) assertEquals(local.proto.pni, result.proto.pni)
} }
@Test @Test
@@ -339,9 +339,9 @@ class ContactRecordProcessorTest {
val result = subject.merge(remote, local, TestKeyGenerator(STORAGE_ID_C)) val result = subject.merge(remote, local, TestKeyGenerator(STORAGE_ID_C))
// THEN // THEN
assertEquals(local.aci, result.aci) assertEquals(local.proto.aci, result.proto.aci)
assertEquals(local.number.get(), result.number.get()) assertEquals(local.proto.e164, result.proto.e164)
assertEquals(local.pni.get(), result.pni.get()) assertEquals(local.proto.pni, result.proto.pni)
} }
@Test @Test
@@ -371,9 +371,9 @@ class ContactRecordProcessorTest {
val result = subject.merge(remote, local, TestKeyGenerator(STORAGE_ID_C)) val result = subject.merge(remote, local, TestKeyGenerator(STORAGE_ID_C))
// THEN // THEN
assertEquals(remote.aci, result.aci) assertEquals(remote.proto.aci, result.proto.aci)
assertEquals(remote.number.get(), result.number.get()) assertEquals(remote.proto.e164, result.proto.e164)
assertEquals(remote.pni.get(), result.pni.get()) assertEquals(remote.proto.pni, result.proto.pni)
} }
@Test @Test
@@ -403,9 +403,9 @@ class ContactRecordProcessorTest {
val result = subject.merge(remote, local, TestKeyGenerator(STORAGE_ID_C)) val result = subject.merge(remote, local, TestKeyGenerator(STORAGE_ID_C))
// THEN // THEN
assertEquals("Ghost", result.nicknameGivenName.get()) assertEquals("Ghost", result.proto.nickname?.given)
assertEquals("Spider", result.nicknameFamilyName.get()) assertEquals("Spider", result.proto.nickname?.family)
assertEquals("Spidey Friend", result.note.get()) assertEquals("Spidey Friend", result.proto.note)
} }
private fun buildRecord(id: StorageId = STORAGE_ID_A, record: ContactRecord): SignalContactRecord { private fun buildRecord(id: StorageId = STORAGE_ID_A, record: ContactRecord): SignalContactRecord {

View File

@@ -14,13 +14,10 @@ import org.thoughtcrime.securesms.recipients.Recipient;
import org.thoughtcrime.securesms.storage.StorageSyncHelper.IdDifferenceResult; import org.thoughtcrime.securesms.storage.StorageSyncHelper.IdDifferenceResult;
import org.thoughtcrime.securesms.util.RemoteConfig; import org.thoughtcrime.securesms.util.RemoteConfig;
import org.whispersystems.signalservice.api.push.ServiceId.ACI; import org.whispersystems.signalservice.api.push.ServiceId.ACI;
import org.whispersystems.signalservice.api.storage.SignalAccountRecord;
import org.whispersystems.signalservice.api.storage.SignalContactRecord; import org.whispersystems.signalservice.api.storage.SignalContactRecord;
import org.whispersystems.signalservice.api.storage.SignalGroupV1Record;
import org.whispersystems.signalservice.api.storage.SignalGroupV2Record;
import org.whispersystems.signalservice.api.storage.SignalRecord; import org.whispersystems.signalservice.api.storage.SignalRecord;
import org.whispersystems.signalservice.api.storage.SignalStorageRecord;
import org.whispersystems.signalservice.api.storage.StorageId; import org.whispersystems.signalservice.api.storage.StorageId;
import org.whispersystems.signalservice.internal.storage.protos.ContactRecord;
import java.util.Arrays; import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
@@ -28,6 +25,8 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import okio.ByteString;
import static junit.framework.TestCase.assertTrue; import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
@@ -132,13 +131,16 @@ public final class StorageSyncHelperTest {
byte[] profileKey = new byte[32]; byte[] profileKey = new byte[32];
byte[] profileKeyCopy = profileKey.clone(); byte[] profileKeyCopy = profileKey.clone();
SignalContactRecord a = contactBuilder(1, ACI_A, E164_A, "a").setProfileKey(profileKey).build(); ContactRecord contactA = contactBuilder(ACI_A, E164_A, "a").profileKey(ByteString.of(profileKey)).build();
SignalContactRecord b = contactBuilder(1, ACI_A, E164_A, "a").setProfileKey(profileKeyCopy).build(); ContactRecord contactB = contactBuilder(ACI_A, E164_A, "a").profileKey(ByteString.of(profileKeyCopy)).build();
assertEquals(a, b); SignalContactRecord signalContactA = new SignalContactRecord(StorageId.forContact(byteArray(1)), contactA);
assertEquals(a.hashCode(), b.hashCode()); SignalContactRecord signalContactB = new SignalContactRecord(StorageId.forContact(byteArray(1)), contactB);
assertFalse(StorageSyncHelper.profileKeyChanged(update(a, b))); assertEquals(signalContactA, signalContactB);
assertEquals(signalContactA.hashCode(), signalContactB.hashCode());
assertFalse(StorageSyncHelper.profileKeyChanged(update(signalContactA, signalContactB)));
} }
@Test @Test
@@ -147,23 +149,23 @@ public final class StorageSyncHelperTest {
byte[] profileKeyCopy = profileKey.clone(); byte[] profileKeyCopy = profileKey.clone();
profileKeyCopy[0] = 1; profileKeyCopy[0] = 1;
SignalContactRecord a = contactBuilder(1, ACI_A, E164_A, "a").setProfileKey(profileKey).build(); ContactRecord contactA = contactBuilder(ACI_A, E164_A, "a").profileKey(ByteString.of(profileKey)).build();
SignalContactRecord b = contactBuilder(1, ACI_A, E164_A, "a").setProfileKey(profileKeyCopy).build(); ContactRecord contactB = contactBuilder(ACI_A, E164_A, "a").profileKey(ByteString.of(profileKeyCopy)).build();
assertNotEquals(a, b); SignalContactRecord signalContactA = new SignalContactRecord(StorageId.forContact(byteArray(1)), contactA);
assertNotEquals(a.hashCode(), b.hashCode()); SignalContactRecord signalContactB = new SignalContactRecord(StorageId.forContact(byteArray(1)), contactB);
assertTrue(StorageSyncHelper.profileKeyChanged(update(a, b))); assertNotEquals(signalContactA, signalContactB);
assertNotEquals(signalContactA.hashCode(), signalContactB.hashCode());
assertTrue(StorageSyncHelper.profileKeyChanged(update(signalContactA, signalContactB)));
} }
private static SignalContactRecord.Builder contactBuilder(int key, private static ContactRecord.Builder contactBuilder(ACI aci, String e164, String profileName) {
ACI aci, return new ContactRecord.Builder()
String e164, .aci(aci.toString())
String profileName) .e164(e164)
{ .givenName(profileName);
return new SignalContactRecord.Builder(byteArray(key), aci, null)
.setE164(e164)
.setProfileGivenName(profileName);
} }
private static <E extends SignalRecord<?>> StorageRecordUpdate<E> update(E oldRecord, E newRecord) { private static <E extends SignalRecord<?>> StorageRecordUpdate<E> update(E oldRecord, E newRecord) {

View File

@@ -37,7 +37,7 @@ sealed class ServiceId(val libSignalServiceId: LibSignalServiceId) {
@JvmOverloads @JvmOverloads
@JvmStatic @JvmStatic
fun parseOrNull(raw: String?, logFailures: Boolean = true): ServiceId? { fun parseOrNull(raw: String?, logFailures: Boolean = true): ServiceId? {
if (raw == null) { if (raw.isNullOrBlank()) {
return null return null
} }

View File

@@ -0,0 +1,15 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.signalservice.api.storage
import org.whispersystems.signalservice.api.push.ServiceId
import org.whispersystems.signalservice.internal.storage.protos.ContactRecord
val ContactRecord.signalAci: ServiceId.ACI?
get() = ServiceId.ACI.parseOrNull(this.aci)
val ContactRecord.signalPni: ServiceId.PNI?
get() = ServiceId.PNI.parseOrNull(this.pni)

View File

@@ -1,55 +1,27 @@
package org.whispersystems.signalservice.api.storage package org.whispersystems.signalservice.api.storage
import org.signal.core.util.hasUnknownFields
import org.signal.libsignal.protocol.logging.Log
import org.whispersystems.signalservice.internal.storage.protos.AccountRecord import org.whispersystems.signalservice.internal.storage.protos.AccountRecord
import java.io.IOException import java.io.IOException
class SignalAccountRecord( /**
* Wrapper around a [AccountRecord] to pair it with a [StorageId].
*/
data class SignalAccountRecord(
override val id: StorageId, override val id: StorageId,
override val proto: AccountRecord override val proto: AccountRecord
) : SignalRecord<AccountRecord> { ) : SignalRecord<AccountRecord> {
companion object { companion object {
private val TAG: String = SignalAccountRecord::class.java.simpleName
fun newBuilder(serializedUnknowns: ByteArray?): AccountRecord.Builder { fun newBuilder(serializedUnknowns: ByteArray?): AccountRecord.Builder {
return if (serializedUnknowns != null) { return serializedUnknowns?.let { builderFromUnknowns(it) } ?: AccountRecord.Builder()
parseUnknowns(serializedUnknowns) }
} else {
private fun builderFromUnknowns(serializedUnknowns: ByteArray): AccountRecord.Builder {
return try {
AccountRecord.ADAPTER.decode(serializedUnknowns).newBuilder()
} catch (e: IOException) {
AccountRecord.Builder() AccountRecord.Builder()
} }
} }
private fun parseUnknowns(serializedUnknowns: ByteArray): AccountRecord.Builder {
try {
return AccountRecord.ADAPTER.decode(serializedUnknowns).newBuilder()
} catch (e: IOException) {
Log.w(TAG, "Failed to combine unknown fields!", e)
return AccountRecord.Builder()
}
}
}
fun serializeUnknownFields(): ByteArray? {
return if (proto.hasUnknownFields()) proto.encode() else null
}
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false
other as SignalAccountRecord
if (id != other.id) return false
if (proto != other.proto) return false
return true
}
override fun hashCode(): Int {
var result = id.hashCode()
result = 31 * result + proto.hashCode()
return result
} }
} }

View File

@@ -5,59 +5,23 @@
package org.whispersystems.signalservice.api.storage package org.whispersystems.signalservice.api.storage
import okio.ByteString.Companion.toByteString
import org.whispersystems.signalservice.internal.storage.protos.CallLinkRecord import org.whispersystems.signalservice.internal.storage.protos.CallLinkRecord
import java.io.IOException import java.io.IOException
/** /**
* A record in storage service that represents a call link that was already created. * Wrapper around a [CallLinkRecord] to pair it with a [StorageId].
*/ */
class SignalCallLinkRecord( data class SignalCallLinkRecord(
override val id: StorageId, override val id: StorageId,
override val proto: CallLinkRecord override val proto: CallLinkRecord
) : SignalRecord<CallLinkRecord> { ) : SignalRecord<CallLinkRecord> {
val rootKey: ByteArray = proto.rootKey.toByteArray()
val adminPassKey: ByteArray = proto.adminPasskey.toByteArray()
val deletionTimestamp: Long = proto.deletedAtTimestampMs
fun isDeleted(): Boolean {
return deletionTimestamp > 0
}
class Builder(rawId: ByteArray, serializedUnknowns: ByteArray?) {
private var id: StorageId = StorageId.forCallLink(rawId)
private var builder: CallLinkRecord.Builder
init {
if (serializedUnknowns != null) {
this.builder = parseUnknowns(serializedUnknowns)
} else {
this.builder = CallLinkRecord.Builder()
}
}
fun setRootKey(rootKey: ByteArray): Builder {
builder.rootKey = rootKey.toByteString()
return this
}
fun setAdminPassKey(adminPasskey: ByteArray): Builder {
builder.adminPasskey = adminPasskey.toByteString()
return this
}
fun setDeletedTimestamp(deletedTimestamp: Long): Builder {
builder.deletedAtTimestampMs = deletedTimestamp
return this
}
fun build(): SignalCallLinkRecord {
return SignalCallLinkRecord(id, builder.build())
}
companion object { companion object {
fun parseUnknowns(serializedUnknowns: ByteArray): CallLinkRecord.Builder { fun newBuilder(serializedUnknowns: ByteArray?): CallLinkRecord.Builder {
return serializedUnknowns?.let { builderFromUnknowns(it) } ?: CallLinkRecord.Builder()
}
private fun builderFromUnknowns(serializedUnknowns: ByteArray): CallLinkRecord.Builder {
return try { return try {
CallLinkRecord.ADAPTER.decode(serializedUnknowns).newBuilder() CallLinkRecord.ADAPTER.decode(serializedUnknowns).newBuilder()
} catch (e: IOException) { } catch (e: IOException) {
@@ -65,5 +29,4 @@ class SignalCallLinkRecord(
} }
} }
} }
}
} }

View File

@@ -1,360 +0,0 @@
package org.whispersystems.signalservice.api.storage;
import org.jetbrains.annotations.NotNull;
import org.signal.core.util.ProtoUtil;
import org.signal.libsignal.protocol.logging.Log;
import org.whispersystems.signalservice.api.push.ServiceId;
import org.whispersystems.signalservice.api.push.ServiceId.ACI;
import org.whispersystems.signalservice.api.push.ServiceId.PNI;
import org.whispersystems.signalservice.api.util.OptionalUtil;
import org.whispersystems.signalservice.internal.storage.protos.ContactRecord;
import org.whispersystems.signalservice.internal.storage.protos.ContactRecord.IdentityState;
import java.io.IOException;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import javax.annotation.Nullable;
import okio.ByteString;
public final class SignalContactRecord implements SignalRecord<ContactRecord> {
private static final String TAG = SignalContactRecord.class.getSimpleName();
private final StorageId id;
private final ContactRecord proto;
private final boolean hasUnknownFields;
private final Optional<ACI> aci;
private final Optional<PNI> pni;
private final Optional<String> e164;
private final Optional<String> profileGivenName;
private final Optional<String> profileFamilyName;
private final Optional<String> systemGivenName;
private final Optional<String> systemFamilyName;
private final Optional<String> systemNickname;
private final Optional<byte[]> profileKey;
private final Optional<String> username;
private final Optional<byte[]> identityKey;
private final Optional<String> nicknameGivenName;
private final Optional<String> nicknameFamilyName;
private final Optional<String> note;
public SignalContactRecord(StorageId id, ContactRecord proto) {
this.id = id;
this.proto = proto;
this.hasUnknownFields = ProtoUtil.hasUnknownFields(proto);
this.aci = OptionalUtil.absentIfEmpty(proto.aci).map(ACI::parseOrNull).map(it -> it.isUnknown() ? null : it);
this.pni = OptionalUtil.absentIfEmpty(proto.pni).map(PNI::parseOrNull).map(it -> it.isUnknown() ? null : it);
this.e164 = OptionalUtil.absentIfEmpty(proto.e164);
this.profileGivenName = OptionalUtil.absentIfEmpty(proto.givenName);
this.profileFamilyName = OptionalUtil.absentIfEmpty(proto.familyName);
this.systemGivenName = OptionalUtil.absentIfEmpty(proto.systemGivenName);
this.systemFamilyName = OptionalUtil.absentIfEmpty(proto.systemFamilyName);
this.systemNickname = OptionalUtil.absentIfEmpty(proto.systemNickname);
this.profileKey = OptionalUtil.absentIfEmpty(proto.profileKey);
this.username = OptionalUtil.absentIfEmpty(proto.username);
this.identityKey = OptionalUtil.absentIfEmpty(proto.identityKey);
this.nicknameGivenName = Optional.ofNullable(proto.nickname).flatMap(n -> OptionalUtil.absentIfEmpty(n.given));
this.nicknameFamilyName = Optional.ofNullable(proto.nickname).flatMap(n -> OptionalUtil.absentIfEmpty(n.family));
this.note = OptionalUtil.absentIfEmpty(proto.note);
}
@Override
public StorageId getId() {
return id;
}
@Override
public ContactRecord getProto() {
return proto;
}
public boolean hasUnknownFields() {
return hasUnknownFields;
}
public byte[] serializeUnknownFields() {
return hasUnknownFields ? proto.encode() : null;
}
public Optional<ACI> getAci() {
return aci;
}
public Optional<PNI> getPni() {
return pni;
}
public Optional<? extends ServiceId> getServiceId() {
if (aci.isPresent()) {
return aci;
} else if (pni.isPresent()) {
return pni;
} else {
return Optional.empty();
}
}
public Optional<String> getNumber() {
return e164;
}
public Optional<String> getProfileGivenName() {
return profileGivenName;
}
public Optional<String> getProfileFamilyName() {
return profileFamilyName;
}
public Optional<String> getSystemGivenName() {
return systemGivenName;
}
public Optional<String> getSystemFamilyName() {
return systemFamilyName;
}
public Optional<String> getSystemNickname() {
return systemNickname;
}
public Optional<String> getNicknameGivenName() {
return nicknameGivenName;
}
public Optional<String> getNicknameFamilyName() {
return nicknameFamilyName;
}
public Optional<String> getNote() {
return note;
}
public Optional<byte[]> getProfileKey() {
return profileKey;
}
public Optional<String> getUsername() {
return username;
}
public Optional<byte[]> getIdentityKey() {
return identityKey;
}
public IdentityState getIdentityState() {
return proto.identityState;
}
public boolean isBlocked() {
return proto.blocked;
}
public boolean isProfileSharingEnabled() {
return proto.whitelisted;
}
public boolean isArchived() {
return proto.archived;
}
public boolean isForcedUnread() {
return proto.markedUnread;
}
public long getMuteUntil() {
return proto.mutedUntilTimestamp;
}
public boolean shouldHideStory() {
return proto.hideStory;
}
public long getUnregisteredTimestamp() {
return proto.unregisteredAtTimestamp;
}
public boolean isHidden() {
return proto.hidden;
}
public boolean isPniSignatureVerified() {
return proto.pniSignatureVerified;
}
/**
* Returns the same record, but stripped of the PNI field. Only used while PNP is in development.
*/
public SignalContactRecord withoutPni() {
return new SignalContactRecord(id, proto.newBuilder().pni("").build());
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
SignalContactRecord that = (SignalContactRecord) o;
return id.equals(that.id) &&
proto.equals(that.proto);
}
@Override
public int hashCode() {
return Objects.hash(id, proto);
}
public static final class Builder {
private final StorageId id;
private final ContactRecord.Builder builder;
public Builder(byte[] rawId, @Nullable ACI aci, byte[] serializedUnknowns) {
this.id = StorageId.forContact(rawId);
if (serializedUnknowns != null) {
this.builder = parseUnknowns(serializedUnknowns);
} else {
this.builder = new ContactRecord.Builder();
}
builder.aci(aci == null ? "" : aci.toString());
}
public Builder setE164(String e164) {
builder.e164(e164 == null ? "" : e164);
return this;
}
public Builder setPni(PNI pni) {
builder.pni(pni == null ? "" : pni.toStringWithoutPrefix());
return this;
}
public Builder setProfileGivenName(String givenName) {
builder.givenName(givenName == null ? "" : givenName);
return this;
}
public Builder setProfileFamilyName(String familyName) {
builder.familyName(familyName == null ? "" : familyName);
return this;
}
public Builder setSystemGivenName(String givenName) {
builder.systemGivenName(givenName == null ? "" : givenName);
return this;
}
public Builder setSystemFamilyName(String familyName) {
builder.systemFamilyName(familyName == null ? "" : familyName);
return this;
}
public Builder setSystemNickname(String nickname) {
builder.systemNickname(nickname == null ? "" : nickname);
return this;
}
public Builder setProfileKey(byte[] profileKey) {
builder.profileKey(profileKey == null ? ByteString.EMPTY : ByteString.of(profileKey));
return this;
}
public Builder setUsername(String username) {
builder.username(username == null ? "" : username);
return this;
}
public Builder setIdentityKey(byte[] identityKey) {
builder.identityKey(identityKey == null ? ByteString.EMPTY : ByteString.of(identityKey));
return this;
}
public Builder setIdentityState(IdentityState identityState) {
builder.identityState(identityState == null ? IdentityState.DEFAULT : identityState);
return this;
}
public Builder setBlocked(boolean blocked) {
builder.blocked(blocked);
return this;
}
public Builder setProfileSharingEnabled(boolean profileSharingEnabled) {
builder.whitelisted(profileSharingEnabled);
return this;
}
public Builder setArchived(boolean archived) {
builder.archived(archived);
return this;
}
public Builder setForcedUnread(boolean forcedUnread) {
builder.markedUnread(forcedUnread);
return this;
}
public Builder setMuteUntil(long muteUntil) {
builder.mutedUntilTimestamp(muteUntil);
return this;
}
public Builder setHideStory(boolean hideStory) {
builder.hideStory(hideStory);
return this;
}
public Builder setUnregisteredTimestamp(long timestamp) {
builder.unregisteredAtTimestamp(timestamp);
return this;
}
public Builder setHidden(boolean hidden) {
builder.hidden(hidden);
return this;
}
public Builder setPniSignatureVerified(boolean verified) {
builder.pniSignatureVerified(verified);
return this;
}
public Builder setNicknameGivenName(String nicknameGivenName) {
ContactRecord.Name.Builder name = builder.nickname == null ? new ContactRecord.Name.Builder() : builder.nickname.newBuilder();
name.given(nicknameGivenName);
builder.nickname(name.build());
return this;
}
public Builder setNicknameFamilyName(String nicknameFamilyName) {
ContactRecord.Name.Builder name = builder.nickname == null ? new ContactRecord.Name.Builder() : builder.nickname.newBuilder();
name.family(nicknameFamilyName);
builder.nickname(name.build());
return this;
}
public Builder setNote(String note) {
builder.note(note == null ? "" : note);
return this;
}
private static ContactRecord.Builder parseUnknowns(byte[] serializedUnknowns) {
try {
return ContactRecord.ADAPTER.decode(serializedUnknowns).newBuilder();
} catch (IOException e) {
Log.w(TAG, "Failed to combine unknown fields!", e);
return new ContactRecord.Builder();
}
}
public SignalContactRecord build() {
return new SignalContactRecord(id, builder.build());
}
}
}

View File

@@ -0,0 +1,27 @@
package org.whispersystems.signalservice.api.storage
import org.whispersystems.signalservice.internal.storage.protos.ContactRecord
import java.io.IOException
/**
* Wrapper around a [ContactRecord] to pair it with a [StorageId].
*/
data class SignalContactRecord(
override val id: StorageId,
override val proto: ContactRecord
) : SignalRecord<ContactRecord> {
companion object {
fun newBuilder(serializedUnknowns: ByteArray?): ContactRecord.Builder {
return serializedUnknowns?.let { builderFromUnknowns(it) } ?: ContactRecord.Builder()
}
private fun builderFromUnknowns(serializedUnknowns: ByteArray): ContactRecord.Builder {
return try {
ContactRecord.ADAPTER.decode(serializedUnknowns).newBuilder()
} catch (e: IOException) {
ContactRecord.Builder()
}
}
}
}

View File

@@ -1,140 +0,0 @@
package org.whispersystems.signalservice.api.storage;
import org.signal.core.util.ProtoUtil;
import org.signal.libsignal.protocol.logging.Log;
import org.whispersystems.signalservice.internal.storage.protos.GroupV1Record;
import java.io.IOException;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import okio.ByteString;
public final class SignalGroupV1Record implements SignalRecord<GroupV1Record> {
private static final String TAG = SignalGroupV1Record.class.getSimpleName();
private final StorageId id;
private final GroupV1Record proto;
private final byte[] groupId;
private final boolean hasUnknownFields;
public SignalGroupV1Record(StorageId id, GroupV1Record proto) {
this.id = id;
this.proto = proto;
this.groupId = proto.id.toByteArray();
this.hasUnknownFields = ProtoUtil.hasUnknownFields(proto);
}
@Override
public StorageId getId() {
return id;
}
@Override public GroupV1Record getProto() {
return proto;
}
public boolean hasUnknownFields() {
return hasUnknownFields;
}
public byte[] serializeUnknownFields() {
return hasUnknownFields ? proto.encode() : null;
}
public byte[] getGroupId() {
return groupId;
}
public boolean isBlocked() {
return proto.blocked;
}
public boolean isProfileSharingEnabled() {
return proto.whitelisted;
}
public boolean isArchived() {
return proto.archived;
}
public boolean isForcedUnread() {
return proto.markedUnread;
}
public long getMuteUntil() {
return proto.mutedUntilTimestamp;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
SignalGroupV1Record that = (SignalGroupV1Record) o;
return id.equals(that.id) &&
proto.equals(that.proto);
}
@Override
public int hashCode() {
return Objects.hash(id, proto);
}
public static final class Builder {
private final StorageId id;
private final GroupV1Record.Builder builder;
public Builder(byte[] rawId, byte[] groupId, byte[] serializedUnknowns) {
this.id = StorageId.forGroupV1(rawId);
if (serializedUnknowns != null) {
this.builder = parseUnknowns(serializedUnknowns);
} else {
this.builder = new GroupV1Record.Builder();
}
builder.id(ByteString.of(groupId));
}
public Builder setBlocked(boolean blocked) {
builder.blocked(blocked);
return this;
}
public Builder setProfileSharingEnabled(boolean profileSharingEnabled) {
builder.whitelisted(profileSharingEnabled);
return this;
}
public Builder setArchived(boolean archived) {
builder.archived(archived);
return this;
}
public Builder setForcedUnread(boolean forcedUnread) {
builder.markedUnread(forcedUnread);
return this;
}
public Builder setMuteUntil(long muteUntil) {
builder.mutedUntilTimestamp(muteUntil);
return this;
}
private static GroupV1Record.Builder parseUnknowns(byte[] serializedUnknowns) {
try {
return GroupV1Record.ADAPTER.decode(serializedUnknowns).newBuilder();
} catch (IOException e) {
Log.w(TAG, "Failed to combine unknown fields!", e);
return new GroupV1Record.Builder();
}
}
public SignalGroupV1Record build() {
return new SignalGroupV1Record(id, builder.build());
}
}
}

View File

@@ -0,0 +1,27 @@
package org.whispersystems.signalservice.api.storage
import org.whispersystems.signalservice.internal.storage.protos.GroupV1Record
import java.io.IOException
/**
* Wrapper around a [GroupV1Record] to pair it with a [StorageId].
*/
data class SignalGroupV1Record(
override val id: StorageId,
override val proto: GroupV1Record
) : SignalRecord<GroupV1Record> {
companion object {
fun newBuilder(serializedUnknowns: ByteArray?): GroupV1Record.Builder {
return serializedUnknowns?.let { builderFromUnknowns(it) } ?: GroupV1Record.Builder()
}
private fun builderFromUnknowns(serializedUnknowns: ByteArray): GroupV1Record.Builder {
return try {
GroupV1Record.ADAPTER.decode(serializedUnknowns).newBuilder()
} catch (e: IOException) {
GroupV1Record.Builder()
}
}
}
}

View File

@@ -1,182 +0,0 @@
package org.whispersystems.signalservice.api.storage;
import org.jetbrains.annotations.NotNull;
import org.signal.core.util.ProtoUtil;
import org.signal.libsignal.protocol.logging.Log;
import org.signal.libsignal.zkgroup.InvalidInputException;
import org.signal.libsignal.zkgroup.groups.GroupMasterKey;
import org.whispersystems.signalservice.internal.storage.protos.GroupV2Record;
import java.io.IOException;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import okio.ByteString;
public final class SignalGroupV2Record implements SignalRecord<GroupV2Record> {
private static final String TAG = SignalGroupV2Record.class.getSimpleName();
private final StorageId id;
private final GroupV2Record proto;
private final byte[] masterKey;
private final boolean hasUnknownFields;
public SignalGroupV2Record(StorageId id, GroupV2Record proto) {
this.id = id;
this.proto = proto;
this.hasUnknownFields = ProtoUtil.hasUnknownFields(proto);
this.masterKey = proto.masterKey.toByteArray();
}
@Override
public StorageId getId() {
return id;
}
@Override public GroupV2Record getProto() {
return proto;
}
public boolean hasUnknownFields() {
return hasUnknownFields;
}
public byte[] serializeUnknownFields() {
return hasUnknownFields ? proto.encode() : null;
}
public byte[] getMasterKeyBytes() {
return masterKey;
}
public GroupMasterKey getMasterKeyOrThrow() {
try {
return new GroupMasterKey(masterKey);
} catch (InvalidInputException e) {
throw new AssertionError(e);
}
}
public boolean isBlocked() {
return proto.blocked;
}
public boolean isProfileSharingEnabled() {
return proto.whitelisted;
}
public boolean isArchived() {
return proto.archived;
}
public boolean isForcedUnread() {
return proto.markedUnread;
}
public long getMuteUntil() {
return proto.mutedUntilTimestamp;
}
public boolean notifyForMentionsWhenMuted() {
return !proto.dontNotifyForMentionsIfMuted;
}
public boolean shouldHideStory() {
return proto.hideStory;
}
public GroupV2Record.StorySendMode getStorySendMode() {
return proto.storySendMode;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
SignalGroupV2Record that = (SignalGroupV2Record) o;
return id.equals(that.id) &&
proto.equals(that.proto);
}
@Override
public int hashCode() {
return Objects.hash(id, proto);
}
public static final class Builder {
private final StorageId id;
private final GroupV2Record.Builder builder;
public Builder(byte[] rawId, GroupMasterKey masterKey, byte[] serializedUnknowns) {
this(rawId, masterKey.serialize(), serializedUnknowns);
}
public Builder(byte[] rawId, byte[] masterKey, byte[] serializedUnknowns) {
this.id = StorageId.forGroupV2(rawId);
if (serializedUnknowns != null) {
this.builder = parseUnknowns(serializedUnknowns);
} else {
this.builder = new GroupV2Record.Builder();
}
builder.masterKey(ByteString.of(masterKey));
}
public Builder setBlocked(boolean blocked) {
builder.blocked(blocked);
return this;
}
public Builder setProfileSharingEnabled(boolean profileSharingEnabled) {
builder.whitelisted(profileSharingEnabled);
return this;
}
public Builder setArchived(boolean archived) {
builder.archived(archived);
return this;
}
public Builder setForcedUnread(boolean forcedUnread) {
builder.markedUnread(forcedUnread);
return this;
}
public Builder setMuteUntil(long muteUntil) {
builder.mutedUntilTimestamp(muteUntil);
return this;
}
public Builder setNotifyForMentionsWhenMuted(boolean value) {
builder.dontNotifyForMentionsIfMuted(!value);
return this;
}
public Builder setHideStory(boolean hideStory) {
builder.hideStory(hideStory);
return this;
}
public Builder setStorySendMode(GroupV2Record.StorySendMode storySendMode) {
builder.storySendMode(storySendMode);
return this;
}
private static GroupV2Record.Builder parseUnknowns(byte[] serializedUnknowns) {
try {
return GroupV2Record.ADAPTER.decode(serializedUnknowns).newBuilder();
} catch (IOException e) {
Log.w(TAG, "Failed to combine unknown fields!", e);
return new GroupV2Record.Builder();
}
}
public SignalGroupV2Record build() {
return new SignalGroupV2Record(id, builder.build());
}
}
}

View File

@@ -0,0 +1,27 @@
package org.whispersystems.signalservice.api.storage
import org.whispersystems.signalservice.internal.storage.protos.GroupV2Record
import java.io.IOException
/**
* Wrapper around a [GroupV2Record] to pair it with a [StorageId].
*/
data class SignalGroupV2Record(
override val id: StorageId,
override val proto: GroupV2Record
) : SignalRecord<GroupV2Record> {
companion object {
fun newBuilder(serializedUnknowns: ByteArray?): GroupV2Record.Builder {
return serializedUnknowns?.let { builderFromUnknowns(it) } ?: GroupV2Record.Builder()
}
private fun builderFromUnknowns(serializedUnknowns: ByteArray): GroupV2Record.Builder {
return try {
GroupV2Record.ADAPTER.decode(serializedUnknowns).newBuilder()
} catch (e: IOException) {
GroupV2Record.Builder()
}
}
}
}

View File

@@ -1,12 +1,20 @@
package org.whispersystems.signalservice.api.storage package org.whispersystems.signalservice.api.storage
import com.squareup.wire.Message
import org.signal.core.util.hasUnknownFields
import kotlin.reflect.KVisibility import kotlin.reflect.KVisibility
import kotlin.reflect.full.memberProperties import kotlin.reflect.full.memberProperties
/**
* Pairs a storage record with its id. Also contains some useful common methods.
*/
interface SignalRecord<E> { interface SignalRecord<E> {
val id: StorageId val id: StorageId
val proto: E val proto: E
val serializedUnknowns: ByteArray?
get() = (proto as Message<*, *>).takeIf { it.hasUnknownFields() }?.encode()
fun describeDiff(other: SignalRecord<*>): String { fun describeDiff(other: SignalRecord<*>): String {
if (this::class != other::class) { if (this::class != other::class) {
return "Classes are different!" return "Classes are different!"

View File

@@ -64,33 +64,12 @@ object SignalStorageModels {
@JvmStatic @JvmStatic
fun localToRemoteStorageRecord(record: SignalStorageRecord, storageKey: StorageKey): StorageItem { fun localToRemoteStorageRecord(record: SignalStorageRecord, storageKey: StorageKey): StorageItem {
val builder = StorageRecord.Builder()
if (record.proto.contact != null) {
builder.contact(record.proto.contact)
} else if (record.proto.groupV1 != null) {
builder.groupV1(record.proto.groupV1)
} else if (record.proto.groupV2 != null) {
builder.groupV2(record.proto.groupV2)
} else if (record.proto.account != null) {
builder.account(record.proto.account)
} else if (record.proto.storyDistributionList != null) {
builder.storyDistributionList(record.proto.storyDistributionList)
} else if (record.proto.callLink != null) {
builder.callLink(record.proto.callLink)
} else {
throw InvalidStorageWriteError()
}
val remoteRecord = builder.build()
val itemKey = storageKey.deriveItemKey(record.id.raw) val itemKey = storageKey.deriveItemKey(record.id.raw)
val encryptedRecord = SignalStorageCipher.encrypt(itemKey, remoteRecord.encode()) val encryptedRecord = SignalStorageCipher.encrypt(itemKey, record.proto.encode())
return StorageItem.Builder() return StorageItem.Builder()
.key(record.id.raw.toByteString()) .key(record.id.raw.toByteString())
.value_(encryptedRecord.toByteString()) .value_(encryptedRecord.toByteString())
.build() .build()
} }
private class InvalidStorageWriteError : Error()
} }

View File

@@ -1,151 +0,0 @@
package org.whispersystems.signalservice.api.storage;
import org.jetbrains.annotations.NotNull;
import org.signal.core.util.ProtoUtil;
import org.signal.libsignal.protocol.logging.Log;
import org.whispersystems.signalservice.api.push.ServiceId;
import org.whispersystems.signalservice.api.push.SignalServiceAddress;
import org.whispersystems.signalservice.internal.storage.protos.StoryDistributionListRecord;
import java.io.IOException;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import okio.ByteString;
public class SignalStoryDistributionListRecord implements SignalRecord<StoryDistributionListRecord> {
private static final String TAG = SignalStoryDistributionListRecord.class.getSimpleName();
private final StorageId id;
private final StoryDistributionListRecord proto;
private final boolean hasUnknownFields;
private final List<SignalServiceAddress> recipients;
public SignalStoryDistributionListRecord(StorageId id, StoryDistributionListRecord proto) {
this.id = id;
this.proto = proto;
this.hasUnknownFields = ProtoUtil.hasUnknownFields(proto);
this.recipients = proto.recipientServiceIds
.stream()
.map(ServiceId::parseOrNull)
.filter(Objects::nonNull)
.map(SignalServiceAddress::new)
.collect(Collectors.toList());
}
@Override
public StorageId getId() {
return id;
}
@Override
public StoryDistributionListRecord getProto() {
return proto;
}
public byte[] serializeUnknownFields() {
return hasUnknownFields ? proto.encode() : null;
}
public byte[] getIdentifier() {
return proto.identifier.toByteArray();
}
public String getName() {
return proto.name;
}
public List<SignalServiceAddress> getRecipients() {
return recipients;
}
public long getDeletedAtTimestamp() {
return proto.deletedAtTimestamp;
}
public boolean allowsReplies() {
return proto.allowsReplies;
}
public boolean isBlockList() {
return proto.isBlockList;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
SignalStoryDistributionListRecord that = (SignalStoryDistributionListRecord) o;
return id.equals(that.id) &&
proto.equals(that.proto);
}
@Override
public int hashCode() {
return Objects.hash(id, proto);
}
public static final class Builder {
private final StorageId id;
private final StoryDistributionListRecord.Builder builder;
public Builder(byte[] rawId, byte[] serializedUnknowns) {
this.id = StorageId.forStoryDistributionList(rawId);
if (serializedUnknowns != null) {
this.builder = parseUnknowns(serializedUnknowns);
} else {
this.builder = new StoryDistributionListRecord.Builder();
}
}
public Builder setIdentifier(byte[] identifier) {
builder.identifier(ByteString.of(identifier));
return this;
}
public Builder setName(String name) {
builder.name(name);
return this;
}
public Builder setRecipients(List<SignalServiceAddress> recipients) {
builder.recipientServiceIds = recipients.stream()
.map(SignalServiceAddress::getIdentifier)
.collect(Collectors.toList());
return this;
}
public Builder setDeletedAtTimestamp(long deletedAtTimestamp) {
builder.deletedAtTimestamp(deletedAtTimestamp);
return this;
}
public Builder setAllowsReplies(boolean allowsReplies) {
builder.allowsReplies(allowsReplies);
return this;
}
public Builder setIsBlockList(boolean isBlockList) {
builder.isBlockList(isBlockList);
return this;
}
public SignalStoryDistributionListRecord build() {
return new SignalStoryDistributionListRecord(id, builder.build());
}
private static StoryDistributionListRecord.Builder parseUnknowns(byte[] serializedUnknowns) {
try {
return StoryDistributionListRecord.ADAPTER.decode(serializedUnknowns).newBuilder();
} catch (IOException e) {
Log.w(TAG, "Failed to combine unknown fields!", e);
return new StoryDistributionListRecord.Builder();
}
}
}
}

View File

@@ -0,0 +1,24 @@
package org.whispersystems.signalservice.api.storage
import org.whispersystems.signalservice.internal.storage.protos.StoryDistributionListRecord
import java.io.IOException
data class SignalStoryDistributionListRecord(
override val id: StorageId,
override val proto: StoryDistributionListRecord
) : SignalRecord<StoryDistributionListRecord> {
companion object {
fun newBuilder(serializedUnknowns: ByteArray?): StoryDistributionListRecord.Builder {
return serializedUnknowns?.let { builderFromUnknowns(it) } ?: StoryDistributionListRecord.Builder()
}
private fun builderFromUnknowns(serializedUnknowns: ByteArray): StoryDistributionListRecord.Builder {
return try {
StoryDistributionListRecord.ADAPTER.decode(serializedUnknowns).newBuilder()
} catch (e: IOException) {
StoryDistributionListRecord.Builder()
}
}
}
}

View File

@@ -7,6 +7,9 @@ import org.whispersystems.signalservice.internal.storage.protos.ManifestRecord;
import java.util.Arrays; import java.util.Arrays;
import java.util.Objects; import java.util.Objects;
/**
* A copy of {@link ManifestRecord.Identifier} that allows us to more easily store unknown types with their integer constant.
*/
public class StorageId { public class StorageId {
private final int type; private final int type;
private final byte[] raw; private final byte[] raw;

View File

@@ -0,0 +1,17 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.signalservice.api.storage
import org.whispersystems.signalservice.api.push.ServiceId
import org.whispersystems.signalservice.api.push.SignalServiceAddress
import org.whispersystems.signalservice.internal.storage.protos.StoryDistributionListRecord
val StoryDistributionListRecord.recipientServiceAddresses: List<SignalServiceAddress>
get() {
return this.recipientServiceIds
.mapNotNull { ServiceId.parseOrNull(it) }
.map { SignalServiceAddress(it) }
}

View File

@@ -46,6 +46,10 @@ public final class UuidUtil {
return new UUID(high, low); return new UUID(high, low);
} }
public static UUID parseOrThrow(ByteString bytes) {
return parseOrNull(bytes.toByteArray());
}
public static boolean isUuid(String uuid) { public static boolean isUuid(String uuid) {
return uuid != null && UUID_PATTERN.matcher(uuid).matches(); return uuid != null && UUID_PATTERN.matcher(uuid).matches();
} }
@@ -83,6 +87,10 @@ public final class UuidUtil {
return byteArray != null && byteArray.length == 16 ? parseOrThrow(byteArray) : null; return byteArray != null && byteArray.length == 16 ? parseOrThrow(byteArray) : null;
} }
public static UUID parseOrNull(ByteString byteString) {
return parseOrNull(byteString.toByteArray());
}
public static List<UUID> fromByteStrings(Collection<ByteString> byteStringCollection) { public static List<UUID> fromByteStrings(Collection<ByteString> byteStringCollection) {
ArrayList<UUID> result = new ArrayList<>(byteStringCollection.size()); ArrayList<UUID> result = new ArrayList<>(byteStringCollection.size());

View File

@@ -1,8 +1,10 @@
package org.whispersystems.signalservice.api.storage; package org.whispersystems.signalservice.api.storage;
import org.junit.Test; import org.junit.Test;
import org.whispersystems.signalservice.api.push.ServiceId;
import org.whispersystems.signalservice.api.push.ServiceId.ACI; import org.whispersystems.signalservice.api.push.ServiceId.ACI;
import org.whispersystems.signalservice.internal.storage.protos.ContactRecord;
import okio.ByteString;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotEquals;
@@ -14,27 +16,33 @@ public class SignalContactRecordTest {
@Test @Test
public void contacts_with_same_identity_key_contents_are_equal() { public void contacts_with_same_identity_key_contents_are_equal() {
byte[] profileKey = new byte[32]; byte[] identityKey = new byte[32];
byte[] profileKeyCopy = profileKey.clone(); byte[] identityKeyCopy = identityKey.clone();
SignalContactRecord a = contactBuilder(1, ACI_A, E164_A, "a").setIdentityKey(profileKey).build(); ContactRecord contactA = contactBuilder(ACI_A, E164_A, "a").identityKey(ByteString.of(identityKey)).build();
SignalContactRecord b = contactBuilder(1, ACI_A, E164_A, "a").setIdentityKey(profileKeyCopy).build(); ContactRecord contactB = contactBuilder(ACI_A, E164_A, "a").identityKey(ByteString.of(identityKeyCopy)).build();
assertEquals(a, b); SignalContactRecord signalContactA = new SignalContactRecord(StorageId.forContact(byteArray(1)), contactA);
assertEquals(a.hashCode(), b.hashCode()); SignalContactRecord signalContactB = new SignalContactRecord(StorageId.forContact(byteArray(1)), contactB);
assertEquals(signalContactA, signalContactB);
assertEquals(signalContactA.hashCode(), signalContactB.hashCode());
} }
@Test @Test
public void contacts_with_different_identity_key_contents_are_not_equal() { public void contacts_with_different_identity_key_contents_are_not_equal() {
byte[] profileKey = new byte[32]; byte[] identityKey = new byte[32];
byte[] profileKeyCopy = profileKey.clone(); byte[] identityKeyCopy = identityKey.clone();
profileKeyCopy[0] = 1; identityKeyCopy[0] = 1;
SignalContactRecord a = contactBuilder(1, ACI_A, E164_A, "a").setIdentityKey(profileKey).build(); ContactRecord contactA = contactBuilder(ACI_A, E164_A, "a").identityKey(ByteString.of(identityKey)).build();
SignalContactRecord b = contactBuilder(1, ACI_A, E164_A, "a").setIdentityKey(profileKeyCopy).build(); ContactRecord contactB = contactBuilder(ACI_A, E164_A, "a").identityKey(ByteString.of(identityKeyCopy)).build();
assertNotEquals(a, b); SignalContactRecord signalContactA = new SignalContactRecord(StorageId.forContact(byteArray(1)), contactA);
assertNotEquals(a.hashCode(), b.hashCode()); SignalContactRecord signalContactB = new SignalContactRecord(StorageId.forContact(byteArray(1)), contactB);
assertNotEquals(signalContactA, signalContactB);
assertNotEquals(signalContactA.hashCode(), signalContactB.hashCode());
} }
private static byte[] byteArray(int a) { private static byte[] byteArray(int a) {
@@ -46,13 +54,9 @@ public class SignalContactRecordTest {
return bytes; return bytes;
} }
private static SignalContactRecord.Builder contactBuilder(int key, private static ContactRecord.Builder contactBuilder(ACI serviceId, String e164, String givenName) {
ACI serviceId, return new ContactRecord.Builder()
String e164, .e164(e164)
String givenName) .givenName(givenName);
{
return new SignalContactRecord.Builder(byteArray(key), serviceId, null)
.setE164(e164)
.setProfileGivenName(givenName);
} }
} }