Separate session store for PNI.

This commit is contained in:
Greyson Parrelli
2022-02-02 11:53:26 -05:00
parent e8ad1e8ed1
commit c2ca899a7c
17 changed files with 320 additions and 337 deletions

View File

@@ -1,39 +0,0 @@
package org.thoughtcrime.securesms.crypto;
import androidx.annotation.NonNull;
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies;
import org.thoughtcrime.securesms.recipients.Recipient;
import org.thoughtcrime.securesms.recipients.RecipientId;
import org.whispersystems.libsignal.SignalProtocolAddress;
import org.whispersystems.libsignal.ecc.ECPublicKey;
import org.whispersystems.libsignal.state.SessionRecord;
import org.whispersystems.signalservice.api.push.SignalServiceAddress;
public class SessionUtil {
public static boolean hasSession(@NonNull RecipientId id) {
SignalProtocolAddress axolotlAddress = new SignalProtocolAddress(Recipient.resolved(id).requireServiceId(), SignalServiceAddress.DEFAULT_DEVICE_ID);
return ApplicationDependencies.getProtocolStore().aci().containsSession(axolotlAddress);
}
public static void archiveSiblingSessions(SignalProtocolAddress address) {
ApplicationDependencies.getProtocolStore().aci().sessions().archiveSiblingSessions(address);
}
public static void archiveAllSessions() {
ApplicationDependencies.getProtocolStore().aci().sessions().archiveAllSessions();
}
public static void archiveSession(RecipientId recipientId, int deviceId) {
ApplicationDependencies.getProtocolStore().aci().sessions().archiveSession(recipientId, deviceId);
}
public static boolean ratchetKeyMatches(@NonNull Recipient recipient, int deviceId, @NonNull ECPublicKey ratchetKey) {
SignalProtocolAddress address = new SignalProtocolAddress(recipient.resolve().requireServiceId(), deviceId);
SessionRecord session = ApplicationDependencies.getProtocolStore().aci().loadSession(address);
return session.currentRatchetKeyMatches(ratchetKey);
}
}

View File

@@ -6,8 +6,6 @@ import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import org.signal.core.util.logging.Log;
import org.thoughtcrime.securesms.crypto.IdentityKeyUtil;
import org.thoughtcrime.securesms.crypto.SessionUtil;
import org.thoughtcrime.securesms.crypto.storage.SignalIdentityKeyStore.SaveResult;
import org.thoughtcrime.securesms.database.IdentityDatabase;
import org.thoughtcrime.securesms.database.IdentityDatabase.VerifiedStatus;
@@ -15,17 +13,16 @@ import org.thoughtcrime.securesms.database.SignalDatabase;
import org.thoughtcrime.securesms.database.identity.IdentityRecordList;
import org.thoughtcrime.securesms.database.model.IdentityRecord;
import org.thoughtcrime.securesms.database.model.IdentityStoreRecord;
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies;
import org.thoughtcrime.securesms.keyvalue.SignalStore;
import org.thoughtcrime.securesms.recipients.Recipient;
import org.thoughtcrime.securesms.recipients.RecipientId;
import org.thoughtcrime.securesms.util.IdentityUtil;
import org.thoughtcrime.securesms.util.LRUCache;
import org.whispersystems.libsignal.IdentityKey;
import org.whispersystems.libsignal.IdentityKeyPair;
import org.whispersystems.libsignal.SignalProtocolAddress;
import org.whispersystems.libsignal.state.IdentityKeyStore;
import org.whispersystems.libsignal.util.guava.Optional;
import org.whispersystems.signalservice.api.push.AccountIdentifier;
import java.util.ArrayList;
import java.util.List;
@@ -90,7 +87,7 @@ public class SignalBaseIdentityKeyStore {
cache.save(address.getName(), recipientId, identityKey, verifiedStatus, false, System.currentTimeMillis(), nonBlockingApproval);
IdentityUtil.markIdentityUpdate(context, recipientId);
SessionUtil.archiveSiblingSessions(address);
ApplicationDependencies.getProtocolStore().aci().sessions().archiveSiblingSessions(address);
SignalDatabase.senderKeyShared().deleteAllFor(recipientId);
return SaveResult.UPDATE;
}

View File

@@ -15,6 +15,7 @@ import org.whispersystems.libsignal.SignalProtocolAddress;
import org.whispersystems.libsignal.protocol.CiphertextMessage;
import org.whispersystems.libsignal.state.SessionRecord;
import org.whispersystems.signalservice.api.SignalServiceSessionStore;
import org.whispersystems.signalservice.api.push.AccountIdentifier;
import java.util.List;
import java.util.Objects;
@@ -27,16 +28,16 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
private static final Object LOCK = new Object();
@NonNull private final Context context;
private final AccountIdentifier accountId;
public TextSecureSessionStore(@NonNull Context context) {
this.context = context;
public TextSecureSessionStore(@NonNull AccountIdentifier accountId) {
this.accountId = accountId;
}
@Override
public SessionRecord loadSession(@NonNull SignalProtocolAddress address) {
synchronized (LOCK) {
SessionRecord sessionRecord = SignalDatabase.sessions().load(address);
SessionRecord sessionRecord = SignalDatabase.sessions().load(accountId, address);
if (sessionRecord == null) {
Log.w(TAG, "No existing session information found for " + address);
@@ -50,7 +51,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
@Override
public List<SessionRecord> loadExistingSessions(List<SignalProtocolAddress> addresses) throws NoSessionException {
synchronized (LOCK) {
List<SessionRecord> sessionRecords = SignalDatabase.sessions().load(addresses);
List<SessionRecord> sessionRecords = SignalDatabase.sessions().load(accountId, addresses);
if (sessionRecords.size() != addresses.size()) {
String message = "Mismatch! Asked for " + addresses.size() + " sessions, but only found " + sessionRecords.size() + "!";
@@ -69,14 +70,14 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
@Override
public void storeSession(@NonNull SignalProtocolAddress address, @NonNull SessionRecord record) {
synchronized (LOCK) {
SignalDatabase.sessions().store(address, record);
SignalDatabase.sessions().store(accountId, address, record);
}
}
@Override
public boolean containsSession(SignalProtocolAddress address) {
synchronized (LOCK) {
SessionRecord sessionRecord = SignalDatabase.sessions().load(address);
SessionRecord sessionRecord = SignalDatabase.sessions().load(accountId, address);
return sessionRecord != null &&
sessionRecord.hasSenderChain() &&
@@ -88,7 +89,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
public void deleteSession(SignalProtocolAddress address) {
synchronized (LOCK) {
Log.w(TAG, "Deleting session for " + address);
SignalDatabase.sessions().delete(address);
SignalDatabase.sessions().delete(accountId, address);
}
}
@@ -96,14 +97,14 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
public void deleteAllSessions(String name) {
synchronized (LOCK) {
Log.w(TAG, "Deleting all sessions for " + name);
SignalDatabase.sessions().deleteAllFor(name);
SignalDatabase.sessions().deleteAllFor(accountId, name);
}
}
@Override
public List<Integer> getSubDeviceSessions(String name) {
synchronized (LOCK) {
return SignalDatabase.sessions().getSubDevices(name);
return SignalDatabase.sessions().getSubDevices(accountId, name);
}
}
@@ -111,7 +112,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
public Set<SignalProtocolAddress> getAllAddressesWithActiveSessions(List<String> addressNames) {
synchronized (LOCK) {
return SignalDatabase.sessions()
.getAllFor(addressNames)
.getAllFor(accountId, addressNames)
.stream()
.filter(row -> isActive(row.getRecord()))
.map(row -> new SignalProtocolAddress(row.getAddress(), row.getDeviceId()))
@@ -122,10 +123,10 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
@Override
public void archiveSession(SignalProtocolAddress address) {
synchronized (LOCK) {
SessionRecord session = SignalDatabase.sessions().load(address);
SessionRecord session = SignalDatabase.sessions().load(accountId, address);
if (session != null) {
session.archiveCurrentState();
SignalDatabase.sessions().store(address, session);
SignalDatabase.sessions().store(accountId, address, session);
}
}
}
@@ -146,7 +147,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
public void archiveSiblingSessions(@NonNull SignalProtocolAddress address) {
synchronized (LOCK) {
List<SessionDatabase.SessionRow> sessions = SignalDatabase.sessions().getAllFor(address.getName());
List<SessionDatabase.SessionRow> sessions = SignalDatabase.sessions().getAllFor(accountId, address.getName());
for (SessionDatabase.SessionRow row : sessions) {
if (row.getDeviceId() != address.getDeviceId()) {
@@ -159,7 +160,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
public void archiveAllSessions() {
synchronized (LOCK) {
List<SessionDatabase.SessionRow> sessions = SignalDatabase.sessions().getAll();
List<SessionDatabase.SessionRow> sessions = SignalDatabase.sessions().getAll(accountId);
for (SessionDatabase.SessionRow row : sessions) {
row.getRecord().archiveCurrentState();
@@ -173,8 +174,4 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
record.hasSenderChain() &&
record.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
}
private static boolean isValidRegistrationId(int registrationId) {
return (registrationId & 0x3fff) == registrationId;
}
}