Migrate the session table to be keyed off of libsignal IDs.

This commit is contained in:
Greyson Parrelli
2021-08-19 14:11:14 -04:00
committed by Alex Hart
parent c24dfdce34
commit 6618d696e4
8 changed files with 142 additions and 156 deletions

View File

@@ -7,7 +7,6 @@ import androidx.annotation.NonNull;
import org.signal.core.util.logging.Log;
import org.thoughtcrime.securesms.database.DatabaseFactory;
import org.thoughtcrime.securesms.database.SessionDatabase;
import org.thoughtcrime.securesms.database.SessionDatabase.RecipientDevice;
import org.thoughtcrime.securesms.recipients.Recipient;
import org.thoughtcrime.securesms.recipients.RecipientId;
import org.whispersystems.libsignal.NoSessionException;
@@ -16,9 +15,7 @@ import org.whispersystems.libsignal.protocol.CiphertextMessage;
import org.whispersystems.libsignal.state.SessionRecord;
import org.whispersystems.signalservice.api.SignalServiceSessionStore;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
public class TextSecureSessionStore implements SignalServiceSessionStore {
@@ -35,8 +32,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
@Override
public SessionRecord loadSession(@NonNull SignalProtocolAddress address) {
synchronized (LOCK) {
RecipientId recipientId = RecipientId.fromExternalPush(address.getName());
SessionRecord sessionRecord = DatabaseFactory.getSessionDatabase(context).load(recipientId, address.getDeviceId());
SessionRecord sessionRecord = DatabaseFactory.getSessionDatabase(context).load(address);
if (sessionRecord == null) {
Log.w(TAG, "No existing session information found.");
@@ -50,11 +46,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
@Override
public List<SessionRecord> loadExistingSessions(List<SignalProtocolAddress> addresses) throws NoSessionException {
synchronized (LOCK) {
List<RecipientDevice> ids = addresses.stream()
.map(address -> new RecipientDevice(RecipientId.fromExternalPush(address.getName()), address.getDeviceId()))
.collect(Collectors.toList());
List<SessionRecord> sessionRecords = DatabaseFactory.getSessionDatabase(context).load(ids);
List<SessionRecord> sessionRecords = DatabaseFactory.getSessionDatabase(context).load(addresses);
if (sessionRecords.size() != addresses.size()) {
String message = "Mismatch! Asked for " + addresses.size() + " sessions, but only found " + sessionRecords.size() + "!";
@@ -69,96 +61,76 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
@Override
public void storeSession(@NonNull SignalProtocolAddress address, @NonNull SessionRecord record) {
synchronized (LOCK) {
RecipientId id = RecipientId.fromExternalPush(address.getName());
DatabaseFactory.getSessionDatabase(context).store(id, address.getDeviceId(), record);
DatabaseFactory.getSessionDatabase(context).store(address, record);
}
}
@Override
public boolean containsSession(SignalProtocolAddress address) {
synchronized (LOCK) {
if (DatabaseFactory.getRecipientDatabase(context).containsPhoneOrUuid(address.getName())) {
RecipientId recipientId = RecipientId.fromExternalPush(address.getName());
SessionRecord sessionRecord = DatabaseFactory.getSessionDatabase(context).load(recipientId, address.getDeviceId());
SessionRecord sessionRecord = DatabaseFactory.getSessionDatabase(context).load(address);
return sessionRecord != null &&
sessionRecord.hasSenderChain() &&
sessionRecord.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
} else {
return false;
}
return sessionRecord != null &&
sessionRecord.hasSenderChain() &&
sessionRecord.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
}
}
@Override
public void deleteSession(SignalProtocolAddress address) {
synchronized (LOCK) {
if (DatabaseFactory.getRecipientDatabase(context).containsPhoneOrUuid(address.getName())) {
RecipientId recipientId = RecipientId.fromExternalPush(address.getName());
DatabaseFactory.getSessionDatabase(context).delete(recipientId, address.getDeviceId());
} else {
Log.w(TAG, "Tried to delete session for " + address.toString() + ", but none existed!");
}
DatabaseFactory.getSessionDatabase(context).delete(address);
}
}
@Override
public void deleteAllSessions(String name) {
synchronized (LOCK) {
if (DatabaseFactory.getRecipientDatabase(context).containsPhoneOrUuid(name)) {
RecipientId recipientId = RecipientId.fromExternalPush(name);
DatabaseFactory.getSessionDatabase(context).deleteAllFor(recipientId);
}
DatabaseFactory.getSessionDatabase(context).deleteAllFor(name);
}
}
@Override
public List<Integer> getSubDeviceSessions(String name) {
synchronized (LOCK) {
if (DatabaseFactory.getRecipientDatabase(context).containsPhoneOrUuid(name)) {
RecipientId recipientId = RecipientId.fromExternalPush(name);
return DatabaseFactory.getSessionDatabase(context).getSubDevices(recipientId);
} else {
Log.w(TAG, "Tried to get sub device sessions for " + name + ", but none existed!");
return Collections.emptyList();
}
return DatabaseFactory.getSessionDatabase(context).getSubDevices(name);
}
}
@Override
public void archiveSession(SignalProtocolAddress address) {
synchronized (LOCK) {
if (DatabaseFactory.getRecipientDatabase(context).containsPhoneOrUuid(address.getName())) {
RecipientId recipientId = RecipientId.fromExternalPush(address.getName());
archiveSession(recipientId, address.getDeviceId());
SessionRecord session = DatabaseFactory.getSessionDatabase(context).load(address);
if (session != null) {
session.archiveCurrentState();
DatabaseFactory.getSessionDatabase(context).store(address, session);
}
}
}
public void archiveSession(@NonNull RecipientId recipientId, int deviceId) {
synchronized (LOCK) {
SessionRecord session = DatabaseFactory.getSessionDatabase(context).load(recipientId, deviceId);
if (session != null) {
session.archiveCurrentState();
DatabaseFactory.getSessionDatabase(context).store(recipientId, deviceId, session);
Recipient recipient = Recipient.resolved(recipientId);
if (recipient.hasUuid()) {
archiveSession(new SignalProtocolAddress(recipient.requireUuid().toString(), deviceId));
}
if (recipient.hasE164()) {
archiveSession(new SignalProtocolAddress(recipient.requireE164(), deviceId));
}
}
}
public void archiveSiblingSessions(@NonNull SignalProtocolAddress address) {
synchronized (LOCK) {
if (DatabaseFactory.getRecipientDatabase(context).containsPhoneOrUuid(address.getName())) {
RecipientId recipientId = RecipientId.fromExternalPush(address.getName());
List<SessionDatabase.SessionRow> sessions = DatabaseFactory.getSessionDatabase(context).getAllFor(recipientId);
List<SessionDatabase.SessionRow> sessions = DatabaseFactory.getSessionDatabase(context).getAllFor(address.getName());
for (SessionDatabase.SessionRow row : sessions) {
if (row.getDeviceId() != address.getDeviceId()) {
row.getRecord().archiveCurrentState();
storeSession(new SignalProtocolAddress(Recipient.resolved(row.getRecipientId()).requireServiceId(), row.getDeviceId()), row.getRecord());
}
for (SessionDatabase.SessionRow row : sessions) {
if (row.getDeviceId() != address.getDeviceId()) {
row.getRecord().archiveCurrentState();
storeSession(new SignalProtocolAddress(row.getAddress(), row.getDeviceId()), row.getRecord());
}
} else {
Log.w(TAG, "Tried to archive sibling sessions for " + address.toString() + ", but none existed!");
}
}
}
@@ -169,7 +141,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
for (SessionDatabase.SessionRow row : sessions) {
row.getRecord().archiveCurrentState();
storeSession(new SignalProtocolAddress(Recipient.resolved(row.getRecipientId()).requireServiceId(), row.getDeviceId()), row.getRecord());
storeSession(new SignalProtocolAddress(row.getAddress(), row.getDeviceId()), row.getRecord());
}
}
}