Add batch identity checks to stories and share/forward flows.

This commit is contained in:
Cody Henthorne
2022-07-26 16:31:55 -04:00
parent 87cb2d6bf8
commit a7a5f2e8c6
9 changed files with 418 additions and 24 deletions

View File

@@ -23,10 +23,11 @@ class ContactSearchMediator(
selectionLimits: SelectionLimits,
displayCheckBox: Boolean,
mapStateToConfiguration: (ContactSearchState) -> ContactSearchConfiguration,
private val contactSelectionPreFilter: (View?, Set<ContactSearchKey>) -> Set<ContactSearchKey> = { _, s -> s }
private val contactSelectionPreFilter: (View?, Set<ContactSearchKey>) -> Set<ContactSearchKey> = { _, s -> s },
performSafetyNumberChecks: Boolean = true
) {
private val viewModel: ContactSearchViewModel = ViewModelProvider(fragment, ContactSearchViewModel.Factory(selectionLimits, ContactSearchRepository())).get(ContactSearchViewModel::class.java)
private val viewModel: ContactSearchViewModel = ViewModelProvider(fragment, ContactSearchViewModel.Factory(selectionLimits, ContactSearchRepository(), performSafetyNumberChecks)).get(ContactSearchViewModel::class.java)
init {

View File

@@ -21,7 +21,9 @@ import org.whispersystems.signalservice.api.util.Preconditions
*/
class ContactSearchViewModel(
private val selectionLimits: SelectionLimits,
private val contactSearchRepository: ContactSearchRepository
private val contactSearchRepository: ContactSearchRepository,
private val performSafetyNumberChecks: Boolean,
private val safetyNumberRepository: SafetyNumberRepository = SafetyNumberRepository(),
) : ViewModel() {
private val disposables = CompositeDisposable()
@@ -75,6 +77,10 @@ class ContactSearchViewModel(
return@subscribe
}
if (performSafetyNumberChecks) {
safetyNumberRepository.batchSafetyNumberCheck(newSelectionEntries)
}
selectionStore.update { state -> state + newSelectionEntries }
}
}
@@ -123,9 +129,13 @@ class ContactSearchViewModel(
controller.value?.onDataInvalidated()
}
class Factory(private val selectionLimits: SelectionLimits, private val repository: ContactSearchRepository) : ViewModelProvider.Factory {
class Factory(
private val selectionLimits: SelectionLimits,
private val repository: ContactSearchRepository,
private val performSafetyNumberChecks: Boolean
) : ViewModelProvider.Factory {
override fun <T : ViewModel> create(modelClass: Class<T>): T {
return modelClass.cast(ContactSearchViewModel(selectionLimits, repository)) as T
return modelClass.cast(ContactSearchViewModel(selectionLimits, repository, performSafetyNumberChecks)) as T
}
}
}

View File

@@ -0,0 +1,114 @@
package org.thoughtcrime.securesms.contacts.paged
import androidx.annotation.VisibleForTesting
import io.reactivex.rxjava3.core.Single
import org.signal.core.util.concurrent.SignalExecutors
import org.signal.core.util.logging.Log
import org.thoughtcrime.securesms.crypto.storage.SignalIdentityKeyStore
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies
import org.thoughtcrime.securesms.recipients.Recipient
import org.thoughtcrime.securesms.recipients.RecipientId
import org.thoughtcrime.securesms.util.IdentityUtil
import org.thoughtcrime.securesms.util.Stopwatch
import org.whispersystems.signalservice.api.services.ProfileService
import org.whispersystems.signalservice.internal.ServiceResponseProcessor
import org.whispersystems.signalservice.internal.push.IdentityCheckResponse
import java.util.concurrent.TimeUnit
/**
* Generic repository for interacting with safety numbers and fetch new ones.
*/
class SafetyNumberRepository(
private val profileService: ProfileService = ApplicationDependencies.getProfileService(),
private val aciIdentityStore: SignalIdentityKeyStore = ApplicationDependencies.getProtocolStore().aci().identities()
) {
private val recentlyFetched: MutableMap<RecipientId, Long> = HashMap()
fun batchSafetyNumberCheck(newSelectionEntries: List<ContactSearchKey>) {
SignalExecutors.UNBOUNDED.execute { batchSafetyNumberCheckSync(newSelectionEntries) }
}
@Suppress("UNCHECKED_CAST")
@VisibleForTesting
fun batchSafetyNumberCheckSync(newSelectionEntries: List<ContactSearchKey>, now: Long = System.currentTimeMillis(), batchSize: Int = MAX_BATCH_SIZE) {
val stopwatch = Stopwatch("batch-snc")
val recipientIds: Set<RecipientId> = newSelectionEntries.flattenToRecipientIds()
stopwatch.split("recipient-ids")
val recentIds = recentlyFetched.filter { (_, timestamp) -> (now - timestamp) < RECENT_TIME_WINDOW }.keys
val recipients = Recipient.resolvedList(recipientIds - recentIds).filter { it.hasServiceId() }
stopwatch.split("recipient-resolve")
if (recipients.isNotEmpty()) {
Log.i(TAG, "Checking on ${recipients.size} identities...")
val requests: List<Single<List<IdentityCheckResponse.AciIdentityPair>>> = recipients.chunked(batchSize) { it.createBatchRequestSingle() }
stopwatch.split("requests")
val aciKeyPairs: List<IdentityCheckResponse.AciIdentityPair> = Single.zip(requests) { responses ->
responses
.map { it as List<IdentityCheckResponse.AciIdentityPair> }
.flatten()
}.blockingGet()
stopwatch.split("batch-fetches")
if (aciKeyPairs.isEmpty()) {
Log.d(TAG, "No identity key mismatches")
} else {
aciKeyPairs
.filter { it.aci != null && it.identityKey != null }
.forEach { IdentityUtil.saveIdentity(it.aci.toString(), it.identityKey) }
}
recentlyFetched += recipients.associate { it.id to now }
stopwatch.split("saving-identities")
}
stopwatch.stop(TAG)
}
private fun List<ContactSearchKey>.flattenToRecipientIds(): Set<RecipientId> {
return this
.map {
when (it) {
is ContactSearchKey.RecipientSearchKey.KnownRecipient -> {
val recipient = Recipient.resolved(it.recipientId)
if (recipient.isGroup) {
recipient.participantIds
} else {
listOf(it.recipientId)
}
}
is ContactSearchKey.RecipientSearchKey.Story -> Recipient.resolved(it.recipientId).participantIds
else -> throw AssertionError("Invalid contact selection $it")
}
}
.flatten()
.toMutableSet()
.apply { remove(Recipient.self().id) }
}
private fun List<Recipient>.createBatchRequestSingle(): Single<List<IdentityCheckResponse.AciIdentityPair>> {
return profileService
.performIdentityCheck(
mapNotNull { r ->
val identityRecord = aciIdentityStore.getIdentityRecord(r.id)
if (identityRecord.isPresent) {
r.requireServiceId() to identityRecord.get().identityKey
} else {
null
}
}.associate { it }
)
.map { ServiceResponseProcessor.DefaultProcessor(it).resultOrThrow.aciKeyPairs ?: emptyList() }
.onErrorReturn { t ->
Log.w(TAG, "Unable to fetch identities", t)
emptyList()
}
}
companion object {
private val TAG = Log.tag(SafetyNumberRepository::class.java)
private val RECENT_TIME_WINDOW = TimeUnit.SECONDS.toMillis(30)
private const val MAX_BATCH_SIZE = 1000
}
}

View File

@@ -16,6 +16,8 @@ public final class IdentityRecordList {
public static final IdentityRecordList EMPTY = new IdentityRecordList(Collections.emptyList());
private static final long DEFAULT_UNTRUSTED_WINDOW = TimeUnit.SECONDS.toMillis(5);
private final List<IdentityRecord> identityRecords;
private final boolean isVerified;
private final boolean isUnverified;
@@ -78,7 +80,7 @@ public final class IdentityRecordList {
continue;
}
if (isUntrusted(identityRecord)) {
if (isUntrusted(identityRecord, DEFAULT_UNTRUSTED_WINDOW)) {
return true;
}
}
@@ -87,10 +89,14 @@ public final class IdentityRecordList {
}
public @NonNull List<IdentityRecord> getUntrustedRecords() {
return getUntrustedRecords(DEFAULT_UNTRUSTED_WINDOW);
}
public @NonNull List<IdentityRecord> getUntrustedRecords(long untrustedWindowMillis) {
List<IdentityRecord> results = new ArrayList<>(identityRecords.size());
for (IdentityRecord identityRecord : identityRecords) {
if (isUntrusted(identityRecord)) {
if (isUntrusted(identityRecord, untrustedWindowMillis)) {
results.add(identityRecord);
}
}
@@ -102,7 +108,7 @@ public final class IdentityRecordList {
List<Recipient> untrusted = new ArrayList<>(identityRecords.size());
for (IdentityRecord identityRecord : identityRecords) {
if (isUntrusted(identityRecord)) {
if (isUntrusted(identityRecord, DEFAULT_UNTRUSTED_WINDOW)) {
untrusted.add(Recipient.resolved(identityRecord.getRecipientId()));
}
}
@@ -134,9 +140,9 @@ public final class IdentityRecordList {
return unverified;
}
private static boolean isUntrusted(@NonNull IdentityRecord identityRecord) {
private static boolean isUntrusted(@NonNull IdentityRecord identityRecord, long untrustedWindowMillis) {
return !identityRecord.isApprovedNonBlocking() &&
System.currentTimeMillis() - identityRecord.getTimestamp() < TimeUnit.SECONDS.toMillis(5);
System.currentTimeMillis() - identityRecord.getTimestamp() < untrustedWindowMillis;
}
}

View File

@@ -10,6 +10,7 @@ import org.thoughtcrime.securesms.database.SignalDatabase
import org.thoughtcrime.securesms.database.model.IdentityRecord
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies
import org.thoughtcrime.securesms.recipients.Recipient
import java.util.concurrent.TimeUnit
object UntrustedRecords {
@@ -41,7 +42,7 @@ object UntrustedRecords {
}
.flatten()
return ApplicationDependencies.getProtocolStore().aci().identities().getIdentityRecords(recipients).untrustedRecords
return ApplicationDependencies.getProtocolStore().aci().identities().getIdentityRecords(recipients).getUntrustedRecords(TimeUnit.SECONDS.toMillis(30))
}
class UntrustedRecordsException(val untrustedRecords: List<IdentityRecord>, val destinations: Set<ContactSearchKey.RecipientSearchKey>) : Throwable()

View File

@@ -62,11 +62,11 @@ class ChooseGroupStoryBottomSheet : FixedRoundedCornerBottomSheetDialogFragment(
val contactRecycler: RecyclerView = view.findViewById(R.id.contact_recycler)
mediator = ContactSearchMediator(
this,
contactRecycler,
FeatureFlags.shareSelectionLimit(),
true,
{ state ->
fragment = this,
recyclerView = contactRecycler,
selectionLimits = FeatureFlags.shareSelectionLimit(),
displayCheckBox = true,
mapStateToConfiguration = { state ->
ContactSearchConfiguration.build {
query = state.query
@@ -77,7 +77,8 @@ class ChooseGroupStoryBottomSheet : FixedRoundedCornerBottomSheetDialogFragment(
)
)
}
}
},
performSafetyNumberChecks = false
)
mediator.getSelectionState().observe(viewLifecycleOwner) { state ->