Clean up message processing locks.

This commit is contained in:
Greyson Parrelli
2021-06-15 10:07:41 -04:00
committed by Cody Henthorne
parent 0d0ee753df
commit c0eac5564c
18 changed files with 76 additions and 167 deletions

View File

@@ -1,64 +0,0 @@
package org.thoughtcrime.securesms.crypto;
import org.whispersystems.signalservice.api.SignalSessionLock;
import java.util.concurrent.locks.ReentrantLock;
/**
* An implementation of {@link SignalSessionLock} that effectively re-uses our database lock.
*/
public enum DatabaseSessionLock implements SignalSessionLock {
INSTANCE;
public static final long NO_OWNER = -1;
private static final ReentrantLock LEGACY_LOCK = new ReentrantLock();
private volatile long ownerThreadId = NO_OWNER;
@Override
public Lock acquire() {
LEGACY_LOCK.lock();
return LEGACY_LOCK::unlock;
// TODO [greyson][db] Revisit after improving database locking
// SQLiteDatabase db = DatabaseFactory.getInstance(ApplicationDependencies.getApplication()).getRawDatabase();
//
// if (db.isDbLockedByCurrentThread()) {
// return () -> {};
// }
//
// db.beginTransaction();
//
// ownerThreadId = Thread.currentThread().getId();
//
// return () -> {
// ownerThreadId = -1;
// db.setTransactionSuccessful();
// db.endTransaction();
// };
}
/**
* Important: Only truly useful for debugging. Do not rely on this for functionality. There's tiny
* windows where this state might not be fully accurate.
*
* @return True if it's likely that some other thread owns this lock, and it's not you.
*/
public boolean isLikelyHeldByOtherThread() {
long ownerThreadId = this.ownerThreadId;
return ownerThreadId != -1 && ownerThreadId == Thread.currentThread().getId();
}
/**
* Important: Only truly useful for debugging. Do not rely on this for functionality. There's a
* tiny window where a thread may still own the lock, but the state we track around it has been
* cleared.
*
* @return The ID of the thread that likely owns this lock, or {@link #NO_OWNER} if no one owns it.
*/
public long getLikeyOwnerThreadId() {
return ownerThreadId;
}
}

View File

@@ -0,0 +1,21 @@
package org.thoughtcrime.securesms.crypto;
import org.whispersystems.signalservice.api.SignalSessionLock;
import java.util.concurrent.locks.ReentrantLock;
/**
* An implementation of {@link SignalSessionLock} that is backed by a {@link ReentrantLock}.
*/
public enum ReentrantSessionLock implements SignalSessionLock {
INSTANCE;
private static final ReentrantLock LOCK = new ReentrantLock();
@Override
public Lock acquire() {
LOCK.lock();
return LOCK::unlock;
}
}

View File

@@ -18,7 +18,7 @@ public final class SenderKeyUtil {
* Clears the state for a sender key session we created. It will naturally get re-created when it is next needed, rotating the key.
*/
public static void rotateOurKey(@NonNull Context context, @NonNull DistributionId distributionId) {
try (SignalSessionLock.Lock unused = DatabaseSessionLock.INSTANCE.acquire()) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
new SignalSenderKeyStore(context).deleteAllFor(Recipient.self().getId(), distributionId);
DatabaseFactory.getSenderKeySharedDatabase(context).deleteAllFor(distributionId);
}
@@ -35,7 +35,7 @@ public final class SenderKeyUtil {
* Deletes all stored state around session keys. Should only really be used when the user is re-registering.
*/
public static void clearAllState(@NonNull Context context) {
try (SignalSessionLock.Lock unused = DatabaseSessionLock.INSTANCE.acquire()) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
new SignalSenderKeyStore(context).deleteAll();
DatabaseFactory.getSenderKeySharedDatabase(context).deleteAll();
}

View File

@@ -4,14 +4,12 @@ import android.content.Context;
import androidx.annotation.NonNull;
import org.thoughtcrime.securesms.crypto.DatabaseSessionLock;
import org.thoughtcrime.securesms.database.DatabaseFactory;
import org.whispersystems.signalservice.api.SignalServiceSenderKeyStore;
import org.whispersystems.signalservice.api.push.DistributionId;
import org.thoughtcrime.securesms.recipients.RecipientId;
import org.whispersystems.libsignal.SignalProtocolAddress;
import org.whispersystems.libsignal.groups.state.SenderKeyRecord;
import org.whispersystems.signalservice.api.SignalSessionLock;
import java.util.Collection;
import java.util.Set;
@@ -25,6 +23,8 @@ import javax.annotation.Nullable;
*/
public final class SignalSenderKeyStore implements SignalServiceSenderKeyStore {
private static final Object LOCK = new Object();
private final Context context;
public SignalSenderKeyStore(@NonNull Context context) {
@@ -33,7 +33,7 @@ public final class SignalSenderKeyStore implements SignalServiceSenderKeyStore {
@Override
public void storeSenderKey(@NonNull SignalProtocolAddress sender, @NonNull UUID distributionId, @NonNull SenderKeyRecord record) {
try (SignalSessionLock.Lock unused = DatabaseSessionLock.INSTANCE.acquire()) {
synchronized (LOCK) {
RecipientId recipientId = RecipientId.fromExternalPush(sender.getName());
DatabaseFactory.getSenderKeyDatabase(context).store(recipientId, sender.getDeviceId(), DistributionId.from(distributionId), record);
}
@@ -41,7 +41,7 @@ public final class SignalSenderKeyStore implements SignalServiceSenderKeyStore {
@Override
public @Nullable SenderKeyRecord loadSenderKey(@NonNull SignalProtocolAddress sender, @NonNull UUID distributionId) {
try (SignalSessionLock.Lock unused = DatabaseSessionLock.INSTANCE.acquire()) {
synchronized (LOCK) {
RecipientId recipientId = RecipientId.fromExternalPush(sender.getName());
return DatabaseFactory.getSenderKeyDatabase(context).load(recipientId, sender.getDeviceId(), DistributionId.from(distributionId));
}
@@ -49,21 +49,21 @@ public final class SignalSenderKeyStore implements SignalServiceSenderKeyStore {
@Override
public Set<SignalProtocolAddress> getSenderKeySharedWith(DistributionId distributionId) {
try (SignalSessionLock.Lock unused = DatabaseSessionLock.INSTANCE.acquire()) {
synchronized (LOCK) {
return DatabaseFactory.getSenderKeySharedDatabase(context).getSharedWith(distributionId);
}
}
@Override
public void markSenderKeySharedWith(DistributionId distributionId, Collection<SignalProtocolAddress> addresses) {
try (SignalSessionLock.Lock unused = DatabaseSessionLock.INSTANCE.acquire()) {
synchronized (LOCK) {
DatabaseFactory.getSenderKeySharedDatabase(context).markAsShared(distributionId, addresses);
}
}
@Override
public void clearSenderKeySharedWith(DistributionId distributionId, Collection<SignalProtocolAddress> addresses) {
try (SignalSessionLock.Lock unused = DatabaseSessionLock.INSTANCE.acquire()) {
synchronized (LOCK) {
DatabaseFactory.getSenderKeySharedDatabase(context).delete(distributionId, addresses);
}
}
@@ -72,7 +72,7 @@ public final class SignalSenderKeyStore implements SignalServiceSenderKeyStore {
* Removes all sender key session state for all devices for the provided recipient-distributionId pair.
*/
public void deleteAllFor(@NonNull RecipientId recipientId, @NonNull DistributionId distributionId) {
try (SignalSessionLock.Lock unused = DatabaseSessionLock.INSTANCE.acquire()) {
synchronized (LOCK) {
DatabaseFactory.getSenderKeyDatabase(context).deleteAllFor(recipientId, distributionId);
}
}
@@ -81,7 +81,7 @@ public final class SignalSenderKeyStore implements SignalServiceSenderKeyStore {
* Deletes all sender key session state.
*/
public void deleteAll() {
try (SignalSessionLock.Lock unused = DatabaseSessionLock.INSTANCE.acquire()) {
synchronized (LOCK) {
DatabaseFactory.getSenderKeyDatabase(context).deleteAll();
}
}

View File

@@ -5,7 +5,6 @@ import android.content.Context;
import androidx.annotation.NonNull;
import org.signal.core.util.logging.Log;
import org.thoughtcrime.securesms.crypto.DatabaseSessionLock;
import org.thoughtcrime.securesms.crypto.IdentityKeyUtil;
import org.thoughtcrime.securesms.crypto.SessionUtil;
import org.thoughtcrime.securesms.database.DatabaseFactory;
@@ -21,7 +20,6 @@ import org.whispersystems.libsignal.IdentityKeyPair;
import org.whispersystems.libsignal.SignalProtocolAddress;
import org.whispersystems.libsignal.state.IdentityKeyStore;
import org.whispersystems.libsignal.util.guava.Optional;
import org.whispersystems.signalservice.api.SignalSessionLock;
import java.util.concurrent.TimeUnit;
@@ -49,7 +47,7 @@ public class TextSecureIdentityKeyStore implements IdentityKeyStore {
}
public @NonNull SaveResult saveIdentity(SignalProtocolAddress address, IdentityKey identityKey, boolean nonBlockingApproval) {
try (SignalSessionLock.Lock unused = DatabaseSessionLock.INSTANCE.acquire()) {
synchronized (LOCK) {
IdentityDatabase identityDatabase = DatabaseFactory.getIdentityDatabase(context);
RecipientId recipientId = RecipientId.fromExternalPush(address.getName());
Optional<IdentityRecord> identityRecord = identityDatabase.getIdentity(recipientId);
@@ -96,7 +94,7 @@ public class TextSecureIdentityKeyStore implements IdentityKeyStore {
@Override
public boolean isTrustedIdentity(SignalProtocolAddress address, IdentityKey identityKey, Direction direction) {
try (SignalSessionLock.Lock unused = DatabaseSessionLock.INSTANCE.acquire()) {
synchronized (LOCK) {
if (DatabaseFactory.getRecipientDatabase(context).containsPhoneOrUuid(address.getName())) {
IdentityDatabase identityDatabase = DatabaseFactory.getIdentityDatabase(context);
RecipientId ourRecipientId = Recipient.self().getId();

View File

@@ -5,14 +5,12 @@ import android.content.Context;
import androidx.annotation.NonNull;
import org.signal.core.util.logging.Log;
import org.thoughtcrime.securesms.crypto.DatabaseSessionLock;
import org.thoughtcrime.securesms.database.DatabaseFactory;
import org.whispersystems.libsignal.InvalidKeyIdException;
import org.whispersystems.libsignal.state.PreKeyRecord;
import org.whispersystems.libsignal.state.PreKeyStore;
import org.whispersystems.libsignal.state.SignedPreKeyRecord;
import org.whispersystems.libsignal.state.SignedPreKeyStore;
import org.whispersystems.signalservice.api.SignalSessionLock;
import java.util.List;
@@ -21,6 +19,8 @@ public class TextSecurePreKeyStore implements PreKeyStore, SignedPreKeyStore {
@SuppressWarnings("unused")
private static final String TAG = Log.tag(TextSecurePreKeyStore.class);
private static final Object LOCK = new Object();
@NonNull
private final Context context;
@@ -30,7 +30,7 @@ public class TextSecurePreKeyStore implements PreKeyStore, SignedPreKeyStore {
@Override
public PreKeyRecord loadPreKey(int preKeyId) throws InvalidKeyIdException {
try (SignalSessionLock.Lock unused = DatabaseSessionLock.INSTANCE.acquire()) {
synchronized (LOCK) {
PreKeyRecord preKeyRecord = DatabaseFactory.getPreKeyDatabase(context).getPreKey(preKeyId);
if (preKeyRecord == null) throw new InvalidKeyIdException("No such key: " + preKeyId);
@@ -40,7 +40,7 @@ public class TextSecurePreKeyStore implements PreKeyStore, SignedPreKeyStore {
@Override
public SignedPreKeyRecord loadSignedPreKey(int signedPreKeyId) throws InvalidKeyIdException {
try (SignalSessionLock.Lock unused = DatabaseSessionLock.INSTANCE.acquire()) {
synchronized (LOCK) {
SignedPreKeyRecord signedPreKeyRecord = DatabaseFactory.getSignedPreKeyDatabase(context).getSignedPreKey(signedPreKeyId);
if (signedPreKeyRecord == null) throw new InvalidKeyIdException("No such signed prekey: " + signedPreKeyId);
@@ -50,21 +50,21 @@ public class TextSecurePreKeyStore implements PreKeyStore, SignedPreKeyStore {
@Override
public List<SignedPreKeyRecord> loadSignedPreKeys() {
try (SignalSessionLock.Lock unused = DatabaseSessionLock.INSTANCE.acquire()) {
synchronized (LOCK) {
return DatabaseFactory.getSignedPreKeyDatabase(context).getAllSignedPreKeys();
}
}
@Override
public void storePreKey(int preKeyId, PreKeyRecord record) {
try (SignalSessionLock.Lock unused = DatabaseSessionLock.INSTANCE.acquire()) {
synchronized (LOCK) {
DatabaseFactory.getPreKeyDatabase(context).insertPreKey(preKeyId, record);
}
}
@Override
public void storeSignedPreKey(int signedPreKeyId, SignedPreKeyRecord record) {
try (SignalSessionLock.Lock unused = DatabaseSessionLock.INSTANCE.acquire()) {
synchronized (LOCK) {
DatabaseFactory.getSignedPreKeyDatabase(context).insertSignedPreKey(signedPreKeyId, record);
}
}

View File

@@ -5,7 +5,6 @@ import android.content.Context;
import androidx.annotation.NonNull;
import org.signal.core.util.logging.Log;
import org.thoughtcrime.securesms.crypto.DatabaseSessionLock;
import org.thoughtcrime.securesms.database.DatabaseFactory;
import org.thoughtcrime.securesms.database.SessionDatabase;
import org.thoughtcrime.securesms.database.SessionDatabase.RecipientDevice;
@@ -16,7 +15,6 @@ import org.whispersystems.libsignal.SignalProtocolAddress;
import org.whispersystems.libsignal.protocol.CiphertextMessage;
import org.whispersystems.libsignal.state.SessionRecord;
import org.whispersystems.signalservice.api.SignalServiceSessionStore;
import org.whispersystems.signalservice.api.SignalSessionLock;
import java.util.Collections;
import java.util.List;
@@ -26,6 +24,8 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
private static final String TAG = Log.tag(TextSecureSessionStore.class);
private static final Object LOCK = new Object();
@NonNull private final Context context;
public TextSecureSessionStore(@NonNull Context context) {
@@ -34,7 +34,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
@Override
public SessionRecord loadSession(@NonNull SignalProtocolAddress address) {
try (SignalSessionLock.Lock unused = DatabaseSessionLock.INSTANCE.acquire()) {
synchronized (LOCK) {
RecipientId recipientId = RecipientId.fromExternalPush(address.getName());
SessionRecord sessionRecord = DatabaseFactory.getSessionDatabase(context).load(recipientId, address.getDeviceId());
@@ -49,7 +49,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
@Override
public List<SessionRecord> loadExistingSessions(List<SignalProtocolAddress> addresses) throws NoSessionException {
try (SignalSessionLock.Lock unused = DatabaseSessionLock.INSTANCE.acquire()) {
synchronized (LOCK) {
List<RecipientDevice> ids = addresses.stream()
.map(address -> new RecipientDevice(RecipientId.fromExternalPush(address.getName()), address.getDeviceId()))
.collect(Collectors.toList());
@@ -68,7 +68,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
@Override
public void storeSession(@NonNull SignalProtocolAddress address, @NonNull SessionRecord record) {
try (SignalSessionLock.Lock unused = DatabaseSessionLock.INSTANCE.acquire()) {
synchronized (LOCK) {
RecipientId id = RecipientId.fromExternalPush(address.getName());
DatabaseFactory.getSessionDatabase(context).store(id, address.getDeviceId(), record);
}
@@ -76,7 +76,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
@Override
public boolean containsSession(SignalProtocolAddress address) {
try (SignalSessionLock.Lock unused = DatabaseSessionLock.INSTANCE.acquire()) {
synchronized (LOCK) {
if (DatabaseFactory.getRecipientDatabase(context).containsPhoneOrUuid(address.getName())) {
RecipientId recipientId = RecipientId.fromExternalPush(address.getName());
SessionRecord sessionRecord = DatabaseFactory.getSessionDatabase(context).load(recipientId, address.getDeviceId());
@@ -92,7 +92,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
@Override
public void deleteSession(SignalProtocolAddress address) {
try (SignalSessionLock.Lock unused = DatabaseSessionLock.INSTANCE.acquire()) {
synchronized (LOCK) {
if (DatabaseFactory.getRecipientDatabase(context).containsPhoneOrUuid(address.getName())) {
RecipientId recipientId = RecipientId.fromExternalPush(address.getName());
DatabaseFactory.getSessionDatabase(context).delete(recipientId, address.getDeviceId());
@@ -104,7 +104,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
@Override
public void deleteAllSessions(String name) {
try (SignalSessionLock.Lock unused = DatabaseSessionLock.INSTANCE.acquire()) {
synchronized (LOCK) {
if (DatabaseFactory.getRecipientDatabase(context).containsPhoneOrUuid(name)) {
RecipientId recipientId = RecipientId.fromExternalPush(name);
DatabaseFactory.getSessionDatabase(context).deleteAllFor(recipientId);
@@ -114,7 +114,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
@Override
public List<Integer> getSubDeviceSessions(String name) {
try (SignalSessionLock.Lock unused = DatabaseSessionLock.INSTANCE.acquire()) {
synchronized (LOCK) {
if (DatabaseFactory.getRecipientDatabase(context).containsPhoneOrUuid(name)) {
RecipientId recipientId = RecipientId.fromExternalPush(name);
return DatabaseFactory.getSessionDatabase(context).getSubDevices(recipientId);
@@ -127,7 +127,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
@Override
public void archiveSession(SignalProtocolAddress address) {
try (SignalSessionLock.Lock unused = DatabaseSessionLock.INSTANCE.acquire()) {
synchronized (LOCK) {
if (DatabaseFactory.getRecipientDatabase(context).containsPhoneOrUuid(address.getName())) {
RecipientId recipientId = RecipientId.fromExternalPush(address.getName());
archiveSession(recipientId, address.getDeviceId());
@@ -136,7 +136,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
}
public void archiveSession(@NonNull RecipientId recipientId, int deviceId) {
try (SignalSessionLock.Lock unused = DatabaseSessionLock.INSTANCE.acquire()) {
synchronized (LOCK) {
SessionRecord session = DatabaseFactory.getSessionDatabase(context).load(recipientId, deviceId);
if (session != null) {
session.archiveCurrentState();
@@ -146,7 +146,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
}
public void archiveSiblingSessions(@NonNull SignalProtocolAddress address) {
try (SignalSessionLock.Lock unused = DatabaseSessionLock.INSTANCE.acquire()) {
synchronized (LOCK) {
if (DatabaseFactory.getRecipientDatabase(context).containsPhoneOrUuid(address.getName())) {
RecipientId recipientId = RecipientId.fromExternalPush(address.getName());
List<SessionDatabase.SessionRow> sessions = DatabaseFactory.getSessionDatabase(context).getAllFor(recipientId);
@@ -164,7 +164,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
}
public void archiveAllSessions() {
try (SignalSessionLock.Lock unused = DatabaseSessionLock.INSTANCE.acquire()) {
synchronized (LOCK) {
List<SessionDatabase.SessionRow> sessions = DatabaseFactory.getSessionDatabase(context).getAll();
for (SessionDatabase.SessionRow row : sessions) {