Basic client usage of CDSHv2.

This provides a basic (read: useful-for-development-yet-broken) client
usage of CDSHv2.
This commit is contained in:
Greyson Parrelli
2022-04-11 19:59:17 -04:00
parent b0e7b49056
commit d3096c56cb
13 changed files with 457 additions and 17 deletions

View File

@@ -0,0 +1,90 @@
package org.thoughtcrime.securesms.database
import android.content.ContentValues
import android.content.Context
import androidx.core.content.contentValuesOf
import org.signal.core.util.SqlUtil
import org.signal.core.util.delete
import org.signal.core.util.logging.Log
import org.signal.core.util.requireNonNullString
import org.signal.core.util.select
import org.signal.core.util.update
/**
* Keeps track of the numbers we've previously queried CDS for.
*
* This is important for rate-limiting: our rate-limiting strategy hinges on keeping
* an accurate history of numbers we've queried so that we're only "charged" for
* querying new numbers.
*/
class CdsDatabase(context: Context, databaseHelper: SignalDatabase) : Database(context, databaseHelper) {
companion object {
private val TAG = Log.tag(CdsDatabase::class.java)
const val TABLE_NAME = "cds"
private const val ID = "_id"
const val E164 = "e164"
private const val LAST_SEEN_AT = "last_seen_at"
const val CREATE_TABLE = """
CREATE TABLE $TABLE_NAME (
$ID INTEGER PRIMARY KEY,
$E164 TEXT NOT NULL UNIQUE ON CONFLICT IGNORE,
$LAST_SEEN_AT INTEGER DEFAULT 0
)
"""
}
fun getAllE164s(): Set<String> {
val e164s: MutableSet<String> = mutableSetOf()
readableDatabase
.select(E164)
.from(TABLE_NAME)
.run()
.use { cursor ->
while (cursor.moveToNext()) {
e164s += cursor.requireNonNullString(E164)
}
}
return e164s
}
/**
* @param newE164s The newly-added E164s that we hadn't previously queried for.
* @param seenE164s The E164s that were seen in either the system contacts or recipients table.
* This should be a superset of [newE164s]
*
*/
fun updateAfterCdsQuery(newE164s: Set<String>, seenE164s: Set<String>) {
val lastSeen = System.currentTimeMillis()
writableDatabase.beginTransaction()
try {
val insertValues: List<ContentValues> = newE164s.map { contentValuesOf(E164 to it) }
SqlUtil.buildBulkInsert(TABLE_NAME, arrayOf(E164), insertValues)
.forEach { writableDatabase.execSQL(it.where, it.whereArgs) }
for (e164 in seenE164s) {
writableDatabase
.update(TABLE_NAME)
.values(LAST_SEEN_AT to lastSeen)
.where("$E164 = ?", e164)
.run()
}
writableDatabase.setTransactionSuccessful()
} finally {
writableDatabase.endTransaction()
}
}
fun clearAll() {
writableDatabase
.delete(TABLE_NAME)
.run()
}
}

View File

@@ -25,6 +25,8 @@ import org.signal.core.util.requireInt
import org.signal.core.util.requireLong
import org.signal.core.util.requireNonNullString
import org.signal.core.util.requireString
import org.signal.core.util.select
import org.signal.core.util.update
import org.signal.libsignal.protocol.IdentityKey
import org.signal.libsignal.protocol.InvalidKeyException
import org.signal.libsignal.zkgroup.InvalidInputException
@@ -237,7 +239,8 @@ open class RecipientDatabase(context: Context, databaseHelper: SignalDatabase) :
val CREATE_INDEXS = arrayOf(
"CREATE INDEX IF NOT EXISTS recipient_group_type_index ON $TABLE_NAME ($GROUP_TYPE);",
"CREATE UNIQUE INDEX IF NOT EXISTS recipient_pni_index ON $TABLE_NAME ($PNI_COLUMN)"
"CREATE UNIQUE INDEX IF NOT EXISTS recipient_pni_index ON $TABLE_NAME ($PNI_COLUMN)",
"CREATE INDEX IF NOT EXISTS recipient_service_id_profile_key ON $TABLE_NAME ($SERVICE_ID, $PROFILE_KEY) WHERE $SERVICE_ID NOT NULL AND $PROFILE_KEY NOT NULL"
)
private val RECIPIENT_PROJECTION: Array<String> = arrayOf(
@@ -499,6 +502,28 @@ open class RecipientDatabase(context: Context, databaseHelper: SignalDatabase) :
}
}
fun getAllServiceIdProfileKeyPairs(): Map<ServiceId, ProfileKey> {
val serviceIdToProfileKey: MutableMap<ServiceId, ProfileKey> = mutableMapOf()
readableDatabase
.select(SERVICE_ID, PROFILE_KEY)
.from(TABLE_NAME)
.where("$SERVICE_ID NOT NULL AND $PROFILE_KEY NOT NULL")
.run()
.use { cursor ->
while (cursor.moveToNext()) {
val serviceId: ServiceId? = ServiceId.parseOrNull(cursor.requireString(SERVICE_ID))
val profileKey: ProfileKey? = ProfileKeyUtil.profileKeyOrNull(cursor.requireString(PROFILE_KEY))
if (serviceId != null && profileKey != null) {
serviceIdToProfileKey[serviceId] = profileKey
}
}
}
return serviceIdToProfileKey
}
private fun fetchRecipient(serviceId: ServiceId?, e164: String?, highTrust: Boolean, changeSelf: Boolean): RecipientFetch {
val byE164 = e164?.let { getByE164(it) } ?: Optional.empty()
val byAci = serviceId?.let { getByServiceId(it) } ?: Optional.empty()
@@ -2087,6 +2112,27 @@ open class RecipientDatabase(context: Context, databaseHelper: SignalDatabase) :
return aciMap
}
/**
* A dumb implementation of processing CDSv2 results. Suitable only for testing and not for actual use.
*/
fun bulkProcessCdsV2Result(mapping: Map<String, CdsV2Result>): Set<RecipientId> {
val ids: MutableSet<RecipientId> = mutableSetOf()
val db = writableDatabase
db.beginTransaction()
try {
for ((e164, result) in mapping) {
ids += getAndPossiblyMerge(result.bestServiceId(), e164, true)
}
db.setTransactionSuccessful()
} finally {
db.endTransaction()
}
return ids
}
fun getUninvitedRecipientsForInsights(): List<RecipientId> {
val results: MutableList<RecipientId> = LinkedList()
val args = arrayOf((System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31)).toString())
@@ -2876,6 +2922,19 @@ open class RecipientDatabase(context: Context, databaseHelper: SignalDatabase) :
}
}
/**
* Should only be used for debugging! A very destructive action that clears all known serviceIds.
*/
fun debugClearServiceIds() {
writableDatabase
.update(TABLE_NAME)
.values(
SERVICE_ID to null,
PNI_COLUMN to null
)
.run()
}
fun getRecord(context: Context, cursor: Cursor): RecipientRecord {
return getRecord(context, cursor, ID)
}
@@ -3431,4 +3490,17 @@ open class RecipientDatabase(context: Context, databaseHelper: SignalDatabase) :
val serviceId: ServiceId? = null,
val e164: String? = null
)
data class CdsV2Result(
val pni: PNI,
val aci: ACI?
) {
fun bestServiceId(): ServiceId {
return if (aci != null) {
aci
} else {
pni
}
}
}
}

View File

@@ -71,6 +71,7 @@ open class SignalDatabase(private val context: Application, databaseSecret: Data
val donationReceiptDatabase: DonationReceiptDatabase = DonationReceiptDatabase(context, this)
val distributionListDatabase: DistributionListDatabase = DistributionListDatabase(context, this)
val storySendsDatabase: StorySendsDatabase = StorySendsDatabase(context, this)
val cdsDatabase: CdsDatabase = CdsDatabase(context, this)
override fun onOpen(db: net.zetetic.database.sqlcipher.SQLiteDatabase) {
db.enableWriteAheadLogging()
@@ -105,6 +106,7 @@ open class SignalDatabase(private val context: Application, databaseSecret: Data
db.execSQL(ReactionDatabase.CREATE_TABLE)
db.execSQL(DonationReceiptDatabase.CREATE_TABLE)
db.execSQL(StorySendsDatabase.CREATE_TABLE)
db.execSQL(CdsDatabase.CREATE_TABLE)
executeStatements(db, SearchDatabase.CREATE_TABLE)
executeStatements(db, RemappedRecordsDatabase.CREATE_TABLE)
executeStatements(db, MessageSendLogDatabase.CREATE_TABLE)
@@ -328,6 +330,11 @@ open class SignalDatabase(private val context: Application, databaseSecret: Data
val avatarPicker: AvatarPickerDatabase
get() = instance!!.avatarPickerDatabase
@get:JvmStatic
@get:JvmName("cds")
val cds: CdsDatabase
get() = instance!!.cdsDatabase
@get:JvmStatic
@get:JvmName("chatColors")
val chatColors: ChatColorsDatabase
@@ -338,6 +345,11 @@ open class SignalDatabase(private val context: Application, databaseSecret: Data
val distributionLists: DistributionListDatabase
get() = instance!!.distributionListDatabase
@get:JvmStatic
@get:JvmName("donationReceipts")
val donationReceipts: DonationReceiptDatabase
get() = instance!!.donationReceiptDatabase
@get:JvmStatic
@get:JvmName("drafts")
val drafts: DraftDatabase
@@ -474,19 +486,14 @@ open class SignalDatabase(private val context: Application, databaseSecret: Data
val stickers: StickerDatabase
get() = instance!!.stickerDatabase
@get:JvmStatic
@get:JvmName("unknownStorageIds")
val unknownStorageIds: UnknownStorageIdDatabase
get() = instance!!.storageIdDatabase
@get:JvmStatic
@get:JvmName("donationReceipts")
val donationReceipts: DonationReceiptDatabase
get() = instance!!.donationReceiptDatabase
@get:JvmStatic
@get:JvmName("storySends")
val storySends: StorySendsDatabase
get() = instance!!.storySendsDatabase
@get:JvmStatic
@get:JvmName("unknownStorageIds")
val unknownStorageIds: UnknownStorageIdDatabase
get() = instance!!.storageIdDatabase
}
}

View File

@@ -193,8 +193,9 @@ object SignalDatabaseMigrations {
private const val STORY_TYPE_AND_DISTRIBUTION = 137
private const val CLEAN_DELETED_DISTRIBUTION_LISTS = 138
private const val REMOVE_KNOWN_UNKNOWNS = 139
private const val CDS_V2 = 140
const val DATABASE_VERSION = 139
const val DATABASE_VERSION = 140
@JvmStatic
fun migrate(context: Application, db: SQLiteDatabase, oldVersion: Int, newVersion: Int) {
@@ -2509,6 +2510,19 @@ object SignalDatabaseMigrations {
val count: Int = db.delete("storage_key", "type <= ?", SqlUtil.buildArgs(4))
Log.i(TAG, "Cleaned up $count invalid unknown records.")
}
if (oldVersion < CDS_V2) {
db.execSQL("CREATE INDEX IF NOT EXISTS recipient_service_id_profile_key ON recipient (uuid, profile_key) WHERE uuid NOT NULL AND profile_key NOT NULL")
db.execSQL(
"""
CREATE TABLE cds (
_id INTEGER PRIMARY KEY,
e164 TEXT NOT NULL UNIQUE ON CONFLICT IGNORE,
last_seen_at INTEGER DEFAULT 0
)
"""
)
}
}
@JvmStatic