Add optimistic locking to account updates

This commit is contained in:
Chris Eager
2021-07-07 11:54:22 -05:00
committed by Jon Chambers
parent 62022c7de1
commit 158d65c6a7
30 changed files with 1397 additions and 399 deletions

View File

@@ -9,7 +9,6 @@ import static com.codahale.metrics.MetricRegistry.name;
import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.auth.basic.BasicCredentials;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags;
import java.time.Clock;
@@ -118,8 +117,7 @@ public class BaseAccountAuthenticator {
Metrics.summary(DAYS_SINCE_LAST_SEEN_DISTRIBUTION_NAME, IS_PRIMARY_DEVICE_TAG, String.valueOf(device.isMaster()))
.record(Duration.ofMillis(todayInMillisWithOffset - device.getLastSeen()).toDays());
device.setLastSeen(Util.todayInMillis(clock));
accountsManager.update(account);
accountsManager.updateDevice(account, device.getId(), d -> d.setLastSeen(Util.todayInMillis(clock)));
}
}

View File

@@ -7,15 +7,15 @@ package org.whispersystems.textsecuregcm.configuration;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
import io.github.resilience4j.circuitbreaker.CircuitBreakerConfig;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import javax.validation.constraints.Max;
import javax.validation.constraints.Min;
import javax.validation.constraints.NotNull;
import java.time.Duration;
import io.github.resilience4j.circuitbreaker.CircuitBreakerConfig;
public class CircuitBreakerConfiguration {
@JsonProperty
@@ -39,6 +39,9 @@ public class CircuitBreakerConfiguration {
@Min(1)
private long waitDurationInOpenStateInSeconds = 10;
@JsonProperty
private List<String> ignoredExceptions = Collections.emptyList();
public int getFailureRateThreshold() {
return failureRateThreshold;
@@ -56,6 +59,18 @@ public class CircuitBreakerConfiguration {
return waitDurationInOpenStateInSeconds;
}
public List<Class> getIgnoredExceptions() {
return ignoredExceptions.stream()
.map(name -> {
try {
return Class.forName(name);
} catch (final ClassNotFoundException e) {
throw new RuntimeException(e);
}
})
.collect(Collectors.toList());
}
@VisibleForTesting
public void setFailureRateThreshold(int failureRateThreshold) {
this.failureRateThreshold = failureRateThreshold;
@@ -76,9 +91,15 @@ public class CircuitBreakerConfiguration {
this.waitDurationInOpenStateInSeconds = seconds;
}
@VisibleForTesting
public void setIgnoredExceptions(final List<String> ignoredExceptions) {
this.ignoredExceptions = ignoredExceptions;
}
public CircuitBreakerConfig toCircuitBreakerConfig() {
return CircuitBreakerConfig.custom()
.failureRateThreshold(getFailureRateThreshold())
.ignoreExceptions(getIgnoredExceptions().toArray(new Class[0]))
.ringBufferSizeInHalfOpenState(getRingBufferSizeInHalfOpenState())
.waitDurationInOpenState(Duration.ofSeconds(getWaitDurationInOpenStateInSeconds()))
.ringBufferSizeInClosedState(getRingBufferSizeInClosedState())

View File

@@ -439,12 +439,12 @@ public class AccountController {
return;
}
device.setApnId(null);
device.setVoipApnId(null);
device.setGcmId(registrationId.getGcmRegistrationId());
device.setFetchesMessages(false);
accounts.update(account);
account = accounts.updateDevice(account, device.getId(), d -> {
d.setApnId(null);
d.setVoipApnId(null);
d.setGcmId(registrationId.getGcmRegistrationId());
d.setFetchesMessages(false);
});
if (!wasAccountEnabled && account.isEnabled()) {
directoryQueue.refreshRegisteredUser(account);
@@ -457,11 +457,12 @@ public class AccountController {
public void deleteGcmRegistrationId(@Auth DisabledPermittedAccount disabledPermittedAccount) {
Account account = disabledPermittedAccount.getAccount();
Device device = account.getAuthenticatedDevice().get();
device.setGcmId(null);
device.setFetchesMessages(false);
device.setUserAgent("OWA");
accounts.update(account);
account = accounts.updateDevice(account, device.getId(), d -> {
d.setGcmId(null);
d.setFetchesMessages(false);
d.setUserAgent("OWA");
});
directoryQueue.refreshRegisteredUser(account);
}
@@ -474,11 +475,12 @@ public class AccountController {
Device device = account.getAuthenticatedDevice().get();
boolean wasAccountEnabled = account.isEnabled();
device.setApnId(registrationId.getApnRegistrationId());
device.setVoipApnId(registrationId.getVoipRegistrationId());
device.setGcmId(null);
device.setFetchesMessages(false);
accounts.update(account);
account = accounts.updateDevice(account, device.getId(), d -> {
d.setApnId(registrationId.getApnRegistrationId());
d.setVoipApnId(registrationId.getVoipRegistrationId());
d.setGcmId(null);
d.setFetchesMessages(false);
});
if (!wasAccountEnabled && account.isEnabled()) {
directoryQueue.refreshRegisteredUser(account);
@@ -491,15 +493,16 @@ public class AccountController {
public void deleteApnRegistrationId(@Auth DisabledPermittedAccount disabledPermittedAccount) {
Account account = disabledPermittedAccount.getAccount();
Device device = account.getAuthenticatedDevice().get();
device.setApnId(null);
device.setFetchesMessages(false);
if (device.getId() == 1) {
device.setUserAgent("OWI");
} else {
device.setUserAgent("OWP");
}
accounts.update(account);
accounts.updateDevice(account, device.getId(), d -> {
d.setApnId(null);
d.setFetchesMessages(false);
if (d.getId() == 1) {
d.setUserAgent("OWI");
} else {
d.setUserAgent("OWP");
}
});
directoryQueue.refreshRegisteredUser(account);
}
@@ -509,18 +512,18 @@ public class AccountController {
@Path("/registration_lock")
public void setRegistrationLock(@Auth Account account, @Valid RegistrationLock accountLock) {
AuthenticationCredentials credentials = new AuthenticationCredentials(accountLock.getRegistrationLock());
account.setRegistrationLock(credentials.getHashedAuthenticationToken(), credentials.getSalt());
account.setPin(null);
accounts.update(account);
accounts.update(account, a -> {
a.setRegistrationLock(credentials.getHashedAuthenticationToken(), credentials.getSalt());
a.setPin(null);
});
}
@Timed
@DELETE
@Path("/registration_lock")
public void removeRegistrationLock(@Auth Account account) {
account.setRegistrationLock(null, null);
accounts.update(account);
accounts.update(account, a -> a.setRegistrationLock(null, null));
}
@Timed
@@ -531,21 +534,21 @@ public class AccountController {
// TODO Remove once PIN-based reglocks have been deprecated
logger.info("PIN set by User-Agent: {}", userAgent);
account.setPin(accountLock.getPin());
account.setRegistrationLock(null, null);
accounts.update(account);
accounts.update(account, a -> {
a.setPin(accountLock.getPin());
a.setRegistrationLock(null, null);
});
}
@Timed
@DELETE
@Path("/pin/")
public void removePin(@Auth Account account, @HeaderParam("User-Agent") String userAgent) {
// TODO Remove once PIN-based reglocks have been deprecated
logger.info("PIN removed by User-Agent: {}", userAgent);
account.setPin(null);
accounts.update(account);
accounts.update(account, a -> a.setPin(null));
}
@Timed
@@ -553,8 +556,8 @@ public class AccountController {
@Path("/name/")
public void setName(@Auth DisabledPermittedAccount disabledPermittedAccount, @Valid DeviceName deviceName) {
Account account = disabledPermittedAccount.getAccount();
account.getAuthenticatedDevice().get().setName(deviceName.getDeviceName());
accounts.update(account);
Device device = account.getAuthenticatedDevice().get();
accounts.updateDevice(account, device.getId(), d -> d.setName(deviceName.getDeviceName()));
}
@Timed
@@ -572,25 +575,29 @@ public class AccountController {
@Valid AccountAttributes attributes)
{
Account account = disabledPermittedAccount.getAccount();
Device device = account.getAuthenticatedDevice().get();
long deviceId = account.getAuthenticatedDevice().get().getId();
device.setFetchesMessages(attributes.getFetchesMessages());
device.setName(attributes.getName());
device.setLastSeen(Util.todayInMillis());
device.setCapabilities(attributes.getCapabilities());
device.setRegistrationId(attributes.getRegistrationId());
device.setUserAgent(userAgent);
account = accounts.update(account, a-> {
setAccountRegistrationLockFromAttributes(account, attributes);
a.getDevice(deviceId).ifPresent(d -> {
d.setFetchesMessages(attributes.getFetchesMessages());
d.setName(attributes.getName());
d.setLastSeen(Util.todayInMillis());
d.setCapabilities(attributes.getCapabilities());
d.setRegistrationId(attributes.getRegistrationId());
d.setUserAgent(userAgent);
});
setAccountRegistrationLockFromAttributes(a, attributes);
a.setUnidentifiedAccessKey(attributes.getUnidentifiedAccessKey());
a.setUnrestrictedUnidentifiedAccess(attributes.isUnrestrictedUnidentifiedAccess());
a.setDiscoverableByPhoneNumber(attributes.isDiscoverableByPhoneNumber());
});
final boolean hasDiscoverabilityChange = (account.isDiscoverableByPhoneNumber() != attributes.isDiscoverableByPhoneNumber());
account.setUnidentifiedAccessKey(attributes.getUnidentifiedAccessKey());
account.setUnrestrictedUnidentifiedAccess(attributes.isUnrestrictedUnidentifiedAccess());
account.setDiscoverableByPhoneNumber(attributes.isDiscoverableByPhoneNumber());
accounts.update(account);
if (hasDiscoverabilityChange) {
directoryQueue.refreshRegisteredUser(account);
}

View File

@@ -100,8 +100,7 @@ public class DeviceController {
}
messages.clear(account.getUuid(), deviceId);
account.removeDevice(deviceId);
accounts.update(account);
account = accounts.update(account, a -> a.removeDevice(deviceId));
directoryQueue.refreshRegisteredUser(account);
// ensure any messages that came in after the first clear() are also removed
messages.clear(account.getUuid(), deviceId);
@@ -192,15 +191,16 @@ public class DeviceController {
device.setName(accountAttributes.getName());
device.setAuthenticationCredentials(new AuthenticationCredentials(password));
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setId(account.get().getNextDeviceId());
device.setRegistrationId(accountAttributes.getRegistrationId());
device.setLastSeen(Util.todayInMillis());
device.setCreated(System.currentTimeMillis());
device.setCapabilities(accountAttributes.getCapabilities());
account.get().addDevice(device);
messages.clear(account.get().getUuid(), device.getId());
accounts.update(account.get());
accounts.update(account.get(), a -> {
device.setId(account.get().getNextDeviceId());
messages.clear(account.get().getUuid(), device.getId());
a.addDevice(device);
});;
pendingDevices.remove(number);
@@ -224,8 +224,8 @@ public class DeviceController {
@Path("/capabilities")
public void setCapabiltities(@Auth Account account, @Valid DeviceCapabilities capabilities) {
assert(account.getAuthenticatedDevice().isPresent());
account.getAuthenticatedDevice().get().setCapabilities(capabilities);
accounts.update(account);
final long deviceId = account.getAuthenticatedDevice().get().getId();
accounts.updateDevice(account, deviceId, d -> d.setCapabilities(capabilities));
}
@VisibleForTesting protected VerificationCode generateVerificationCode() {

View File

@@ -104,17 +104,18 @@ public class KeysController {
boolean updateAccount = false;
if (!preKeys.getSignedPreKey().equals(device.getSignedPreKey())) {
device.setSignedPreKey(preKeys.getSignedPreKey());
updateAccount = true;
}
if (!preKeys.getIdentityKey().equals(account.getIdentityKey())) {
account.setIdentityKey(preKeys.getIdentityKey());
updateAccount = true;
}
if (updateAccount) {
accounts.update(account);
account = accounts.update(account, a -> {
a.getDevice(device.getId()).ifPresent(d -> d.setSignedPreKey(preKeys.getSignedPreKey()));
a.setIdentityKey(preKeys.getIdentityKey());
});
if (!wasAccountEnabled && account.isEnabled()) {
directoryQueue.refreshRegisteredUser(account);
@@ -200,8 +201,7 @@ public class KeysController {
Device device = account.getAuthenticatedDevice().get();
boolean wasAccountEnabled = account.isEnabled();
device.setSignedPreKey(signedPreKey);
accounts.update(account);
account = accounts.updateDevice(account, device.getId(), d -> d.setSignedPreKey(signedPreKey));
if (!wasAccountEnabled && account.isEnabled()) {
directoryQueue.refreshRegisteredUser(account);

View File

@@ -156,10 +156,11 @@ public class ProfileController {
response = Optional.of(generateAvatarUploadForm(avatar));
}
account.setProfileName(request.getName());
account.setAvatar(avatar);
account.setCurrentProfileVersion(request.getVersion());
accountsManager.update(account);
accountsManager.update(account, a -> {
a.setProfileName(request.getName());
a.setAvatar(avatar);
a.setCurrentProfileVersion(request.getVersion());
});
if (response.isPresent()) return Response.ok(response).build();
else return Response.ok().build();
@@ -317,8 +318,7 @@ public class ProfileController {
@Produces(MediaType.APPLICATION_JSON)
@Path("/name/{name}")
public void setProfile(@Auth Account account, @PathParam("name") @ExactlySize(value = {72, 108}, payload = {Unwrapping.Unwrap.class}) Optional<String> name) {
account.setProfileName(name.orElse(null));
accountsManager.update(account);
accountsManager.update(account, a -> a.setProfileName(name.orElse(null)));
}
@Deprecated
@@ -382,8 +382,7 @@ public class ProfileController {
.build());
}
account.setAvatar(objectName);
accountsManager.update(account);
accountsManager.update(account, a -> a.setAvatar(objectName));
return profileAvatarUploadAttributes;
}

View File

@@ -110,8 +110,8 @@ public class GCMSender {
Device device = account.get().getDevice(message.getDeviceId()).get();
if (device.getUninstalledFeedbackTimestamp() == 0) {
device.setUninstalledFeedbackTimestamp(Util.todayInMillis());
accountsManager.update(account.get());
accountsManager.updateDevice(account.get(), message.getDeviceId(), d ->
d.setUninstalledFeedbackTimestamp(Util.todayInMillis()));
}
}
@@ -122,15 +122,11 @@ public class GCMSender {
logger.warn(String.format("Actually received 'CanonicalRegistrationId' ::: (canonical=%s), (original=%s)",
result.getCanonicalRegistrationId(), message.getGcmId()));
Optional<Account> account = getAccountForEvent(message);
if (account.isPresent()) {
//noinspection OptionalGetWithoutIsPresent
Device device = account.get().getDevice(message.getDeviceId()).get();
device.setGcmId(result.getCanonicalRegistrationId());
accountsManager.update(account.get());
}
getAccountForEvent(message).ifPresent(account ->
accountsManager.updateDevice(
account,
message.getDeviceId(),
d -> d.setGcmId(result.getCanonicalRegistrationId())));
canonical.mark();
}

View File

@@ -14,11 +14,16 @@ import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import javax.security.auth.Subject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.StoredRegistrationLock;
public class Account implements Principal {
@JsonIgnore
private static final Logger logger = LoggerFactory.getLogger(Account.class);
@JsonIgnore
private UUID uuid;
@@ -58,12 +63,15 @@ public class Account implements Principal {
@JsonProperty("inCds")
private boolean discoverableByPhoneNumber = true;
@JsonProperty("_ddbV")
private int dynamoDbMigrationVersion;
@JsonIgnore
private Device authenticatedDevice;
@JsonProperty
private int version;
@JsonIgnore
private boolean stale;
public Account() {}
@VisibleForTesting
@@ -75,47 +83,68 @@ public class Account implements Principal {
}
public Optional<Device> getAuthenticatedDevice() {
requireNotStale();
return Optional.ofNullable(authenticatedDevice);
}
public void setAuthenticatedDevice(Device device) {
requireNotStale();
this.authenticatedDevice = device;
}
public UUID getUuid() {
// this is the one method that may be called on a stale account
return uuid;
}
public void setUuid(UUID uuid) {
requireNotStale();
this.uuid = uuid;
}
public void setNumber(String number) {
requireNotStale();
this.number = number;
}
public String getNumber() {
requireNotStale();
return number;
}
public void addDevice(Device device) {
requireNotStale();
this.devices.remove(device);
this.devices.add(device);
}
public void removeDevice(long deviceId) {
requireNotStale();
this.devices.remove(new Device(deviceId, null, null, null, null, null, null, false, 0, null, 0, 0, "NA", 0, null));
}
public Set<Device> getDevices() {
requireNotStale();
return devices;
}
public Optional<Device> getMasterDevice() {
requireNotStale();
return getDevice(Device.MASTER_ID);
}
public Optional<Device> getDevice(long deviceId) {
requireNotStale();
for (Device device : devices) {
if (device.getId() == deviceId) {
return Optional.of(device);
@@ -126,42 +155,58 @@ public class Account implements Principal {
}
public boolean isGroupsV2Supported() {
requireNotStale();
return devices.stream()
.filter(Device::isEnabled)
.allMatch(Device::isGroupsV2Supported);
}
public boolean isStorageSupported() {
requireNotStale();
return devices.stream().anyMatch(device -> device.getCapabilities() != null && device.getCapabilities().isStorage());
}
public boolean isTransferSupported() {
requireNotStale();
return getMasterDevice().map(Device::getCapabilities).map(Device.DeviceCapabilities::isTransfer).orElse(false);
}
public boolean isGv1MigrationSupported() {
requireNotStale();
return devices.stream()
.filter(Device::isEnabled)
.allMatch(device -> device.getCapabilities() != null && device.getCapabilities().isGv1Migration());
}
public boolean isSenderKeySupported() {
requireNotStale();
return devices.stream()
.filter(Device::isEnabled)
.allMatch(device -> device.getCapabilities() != null && device.getCapabilities().isSenderKey());
}
public boolean isAnnouncementGroupSupported() {
requireNotStale();
return devices.stream()
.filter(Device::isEnabled)
.allMatch(device -> device.getCapabilities() != null && device.getCapabilities().isAnnouncementGroup());
}
public boolean isEnabled() {
requireNotStale();
return getMasterDevice().map(Device::isEnabled).orElse(false);
}
public long getNextDeviceId() {
requireNotStale();
long highestDevice = Device.MASTER_ID;
for (Device device : devices) {
@@ -176,6 +221,8 @@ public class Account implements Principal {
}
public int getEnabledDeviceCount() {
requireNotStale();
int count = 0;
for (Device device : devices) {
@@ -186,22 +233,32 @@ public class Account implements Principal {
}
public boolean isRateLimited() {
requireNotStale();
return true;
}
public Optional<String> getRelay() {
requireNotStale();
return Optional.empty();
}
public void setIdentityKey(String identityKey) {
requireNotStale();
this.identityKey = identityKey;
}
public String getIdentityKey() {
requireNotStale();
return identityKey;
}
public long getLastSeen() {
requireNotStale();
long lastSeen = 0;
for (Device device : devices) {
@@ -214,78 +271,127 @@ public class Account implements Principal {
}
public Optional<String> getCurrentProfileVersion() {
requireNotStale();
return Optional.ofNullable(currentProfileVersion);
}
public void setCurrentProfileVersion(String currentProfileVersion) {
requireNotStale();
this.currentProfileVersion = currentProfileVersion;
}
public String getProfileName() {
requireNotStale();
return name;
}
public void setProfileName(String name) {
requireNotStale();
this.name = name;
}
public String getAvatar() {
requireNotStale();
return avatar;
}
public void setAvatar(String avatar) {
requireNotStale();
this.avatar = avatar;
}
public void setPin(String pin) {
requireNotStale();
this.pin = pin;
}
public void setRegistrationLock(String registrationLock, String registrationLockSalt) {
requireNotStale();
this.registrationLock = registrationLock;
this.registrationLockSalt = registrationLockSalt;
}
public StoredRegistrationLock getRegistrationLock() {
requireNotStale();
return new StoredRegistrationLock(Optional.ofNullable(registrationLock), Optional.ofNullable(registrationLockSalt), Optional.ofNullable(pin), getLastSeen());
}
public Optional<byte[]> getUnidentifiedAccessKey() {
requireNotStale();
return Optional.ofNullable(unidentifiedAccessKey);
}
public void setUnidentifiedAccessKey(byte[] unidentifiedAccessKey) {
requireNotStale();
this.unidentifiedAccessKey = unidentifiedAccessKey;
}
public boolean isUnrestrictedUnidentifiedAccess() {
requireNotStale();
return unrestrictedUnidentifiedAccess;
}
public void setUnrestrictedUnidentifiedAccess(boolean unrestrictedUnidentifiedAccess) {
requireNotStale();
this.unrestrictedUnidentifiedAccess = unrestrictedUnidentifiedAccess;
}
public boolean isFor(AmbiguousIdentifier identifier) {
requireNotStale();
if (identifier.hasUuid()) return identifier.getUuid().equals(uuid);
else if (identifier.hasNumber()) return identifier.getNumber().equals(number);
else throw new AssertionError();
}
public boolean isDiscoverableByPhoneNumber() {
requireNotStale();
return this.discoverableByPhoneNumber;
}
public void setDiscoverableByPhoneNumber(final boolean discoverableByPhoneNumber) {
requireNotStale();
this.discoverableByPhoneNumber = discoverableByPhoneNumber;
}
public int getDynamoDbMigrationVersion() {
return dynamoDbMigrationVersion;
public int getVersion() {
requireNotStale();
return version;
}
public void setDynamoDbMigrationVersion(int dynamoDbMigrationVersion) {
this.dynamoDbMigrationVersion = dynamoDbMigrationVersion;
public void setVersion(int version) {
requireNotStale();
this.version = version;
}
public void markStale() {
stale = true;
}
private void requireNotStale() {
assert !stale;
//noinspection ConstantConditions
if (stale) {
logger.error("Accessor called on stale account", new RuntimeException());
}
}
// Principal implementation

View File

@@ -7,7 +7,7 @@ public interface AccountStore {
boolean create(Account account);
void update(Account account);
void update(Account account) throws ContestedOptimisticLockException;
Optional<Account> get(String number);

View File

@@ -13,6 +13,7 @@ import com.codahale.metrics.Timer.Context;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import org.jdbi.v3.core.transaction.TransactionIsolationLevel;
@@ -22,10 +23,11 @@ import org.whispersystems.textsecuregcm.util.SystemMapper;
public class Accounts implements AccountStore {
public static final String ID = "id";
public static final String UID = "uuid";
public static final String ID = "id";
public static final String UID = "uuid";
public static final String NUMBER = "number";
public static final String DATA = "data";
public static final String DATA = "data";
public static final String VERSION = "version";
private static final ObjectMapper mapper = SystemMapper.getMapper();
@@ -50,15 +52,19 @@ public class Accounts implements AccountStore {
public boolean create(Account account) {
return database.with(jdbi -> jdbi.inTransaction(TransactionIsolationLevel.SERIALIZABLE, handle -> {
try (Timer.Context ignored = createTimer.time()) {
UUID uuid = handle.createQuery("INSERT INTO accounts (" + NUMBER + ", " + UID + ", " + DATA + ") VALUES (:number, :uuid, CAST(:data AS json)) ON CONFLICT(number) DO UPDATE SET data = EXCLUDED.data RETURNING uuid")
.bind("number", account.getNumber())
.bind("uuid", account.getUuid())
.bind("data", mapper.writeValueAsString(account))
.mapTo(UUID.class)
.findOnly();
final Map<String, Object> resultMap = handle.createQuery("INSERT INTO accounts (" + NUMBER + ", " + UID + ", " + DATA + ") VALUES (:number, :uuid, CAST(:data AS json)) ON CONFLICT(number) DO UPDATE SET " + DATA + " = EXCLUDED.data, " + VERSION + " = accounts.version + 1 RETURNING uuid, version")
.bind("number", account.getNumber())
.bind("uuid", account.getUuid())
.bind("data", mapper.writeValueAsString(account))
.mapToMap()
.findOnly();
final UUID uuid = (UUID) resultMap.get(UID);
final int version = (int) resultMap.get(VERSION);
boolean isNew = uuid.equals(account.getUuid());
account.setUuid(uuid);
account.setVersion(version);
return isNew;
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(e);
@@ -67,13 +73,23 @@ public class Accounts implements AccountStore {
}
@Override
public void update(Account account) {
public void update(Account account) throws ContestedOptimisticLockException {
database.use(jdbi -> jdbi.useHandle(handle -> {
try (Timer.Context ignored = updateTimer.time()) {
handle.createUpdate("UPDATE accounts SET " + DATA + " = CAST(:data AS json) WHERE " + UID + " = :uuid")
final int newVersion = account.getVersion() + 1;
int rowsModified = handle.createUpdate("UPDATE accounts SET " + DATA + " = CAST(:data AS json), " + VERSION + " = :newVersion WHERE " + UID + " = :uuid AND " + VERSION + " = :version")
.bind("uuid", account.getUuid())
.bind("data", mapper.writeValueAsString(account))
.bind("version", account.getVersion())
.bind("newVersion", newVersion)
.execute();
if (rowsModified == 0) {
throw new ContestedOptimisticLockException();
}
account.setVersion(newVersion);
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(e);
}

View File

@@ -30,16 +30,20 @@ import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.CancellationReason;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
import software.amazon.awssdk.services.dynamodb.model.Delete;
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemResponse;
import software.amazon.awssdk.services.dynamodb.model.Put;
import software.amazon.awssdk.services.dynamodb.model.ReturnValue;
import software.amazon.awssdk.services.dynamodb.model.ReturnValuesOnConditionCheckFailure;
import software.amazon.awssdk.services.dynamodb.model.ScanRequest;
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem;
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItemsRequest;
import software.amazon.awssdk.services.dynamodb.model.TransactionCanceledException;
import software.amazon.awssdk.services.dynamodb.model.TransactionConflictException;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemResponse;
public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountStore {
@@ -49,8 +53,8 @@ public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountSt
static final String ATTR_ACCOUNT_E164 = "P";
// account, serialized to JSON
static final String ATTR_ACCOUNT_DATA = "D";
static final String ATTR_MIGRATION_VERSION = "V";
// internal version for optimistic locking
static final String ATTR_VERSION = "V";
private final DynamoDbClient client;
private final DynamoDbAsyncClient asyncClient;
@@ -122,11 +126,19 @@ public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountSt
ByteBuffer actualAccountUuid = phoneNumberConstraintCancellationReason.item().get(KEY_ACCOUNT_UUID).b().asByteBuffer();
account.setUuid(UUIDUtil.fromByteBuffer(actualAccountUuid));
final int version = get(account.getUuid()).get().getVersion();
account.setVersion(version);
update(account);
return false;
}
if ("TransactionConflict".equals(accountCancellationReason.code())) {
// this should only happen during concurrent update()s for an account migration
throw new ContestedOptimisticLockException();
}
// this shouldnt happen
throw new RuntimeException("could not create account: " + extractCancellationReasonCodes(e));
}
@@ -146,7 +158,7 @@ public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountSt
KEY_ACCOUNT_UUID, AttributeValues.fromUUID(uuid),
ATTR_ACCOUNT_E164, AttributeValues.fromString(account.getNumber()),
ATTR_ACCOUNT_DATA, AttributeValues.fromByteArray(SystemMapper.getMapper().writeValueAsBytes(account)),
ATTR_MIGRATION_VERSION, AttributeValues.fromInt(account.getDynamoDbMigrationVersion())))
ATTR_VERSION, AttributeValues.fromInt(account.getVersion())))
.build())
.build();
}
@@ -172,28 +184,44 @@ public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountSt
}
@Override
public void update(Account account) {
public void update(Account account) throws ContestedOptimisticLockException {
UPDATE_TIMER.record(() -> {
UpdateItemRequest updateItemRequest;
try {
updateItemRequest = UpdateItemRequest.builder()
.tableName(accountsTableName)
.key(Map.of(KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid())))
.updateExpression("SET #data = :data, #version = :version")
.conditionExpression("attribute_exists(#number)")
.updateExpression("SET #data = :data ADD #version :version_increment")
.conditionExpression("attribute_exists(#number) AND #version = :version")
.expressionAttributeNames(Map.of("#number", ATTR_ACCOUNT_E164,
"#data", ATTR_ACCOUNT_DATA,
"#version", ATTR_MIGRATION_VERSION))
"#version", ATTR_VERSION))
.expressionAttributeValues(Map.of(
":data", AttributeValues.fromByteArray(SystemMapper.getMapper().writeValueAsBytes(account)),
":version", AttributeValues.fromInt(account.getDynamoDbMigrationVersion())))
":version", AttributeValues.fromInt(account.getVersion()),
":version_increment", AttributeValues.fromInt(1)))
.returnValues(ReturnValue.UPDATED_NEW)
.build();
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(e);
}
client.updateItem(updateItemRequest);
try {
UpdateItemResponse response = client.updateItem(updateItemRequest);
account.setVersion(AttributeValues.getInt(response.attributes(), "V", account.getVersion() + 1));
} catch (final TransactionConflictException e) {
throw new ContestedOptimisticLockException();
} catch (final ConditionalCheckFailedException e) {
// the exception doesnt give details about which condition failed,
// but we can infer it was an optimistic locking failure if the UUID is known
throw get(account.getUuid()).isPresent() ? new ContestedOptimisticLockException() : e;
}
});
}
@@ -343,9 +371,9 @@ public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountSt
.conditionExpression("attribute_not_exists(#uuid) OR (attribute_exists(#uuid) AND #version < :version)")
.expressionAttributeNames(Map.of(
"#uuid", KEY_ACCOUNT_UUID,
"#version", ATTR_MIGRATION_VERSION))
"#version", ATTR_VERSION))
.expressionAttributeValues(Map.of(
":version", AttributeValues.fromInt(account.getDynamoDbMigrationVersion()))));
":version", AttributeValues.fromInt(account.getVersion()))));
final TransactWriteItemsRequest request = TransactWriteItemsRequest.builder()
.transactItems(phoneNumberConstraintPut, accountPut).build();
@@ -395,6 +423,7 @@ public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountSt
Account account = SystemMapper.getMapper().readValue(item.get(ATTR_ACCOUNT_DATA).b().asByteArray(), Account.class);
account.setNumber(item.get(ATTR_ACCOUNT_E164).s());
account.setUuid(UUIDUtil.fromByteBuffer(item.get(KEY_ACCOUNT_UUID).b().asByteBuffer()));
account.setVersion(Integer.parseInt(item.get(ATTR_VERSION).n()));
return account;

View File

@@ -26,6 +26,8 @@ import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import net.logstash.logback.argument.StructuredArguments;
import org.apache.commons.lang3.StringUtils;
@@ -40,7 +42,6 @@ import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.Util;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
public class AccountsManager {
@@ -119,7 +120,6 @@ public class AccountsManager {
this.mapper = SystemMapper.getMapper();
this.migrationComparisonMapper = mapper.copy();
migrationComparisonMapper.addMixIn(Account.class, AccountComparisonMixin.class);
migrationComparisonMapper.addMixIn(Device.class, DeviceComparisonMixin.class);
this.dynamicConfigurationManager = dynamicConfigurationManager;
@@ -169,25 +169,86 @@ public class AccountsManager {
}
}
public void update(Account account) {
public Account update(Account account, Consumer<Account> updater) {
final Account updatedAccount;
try (Timer.Context ignored = updateTimer.time()) {
account.setDynamoDbMigrationVersion(account.getDynamoDbMigrationVersion() + 1);
redisSet(account);
databaseUpdate(account);
updater.accept(account);
{
// optimistically increment version
final int originalVersion = account.getVersion();
account.setVersion(originalVersion + 1);
redisSet(account);
account.setVersion(originalVersion);
}
final UUID uuid = account.getUuid();
updatedAccount = updateWithRetries(account, updater, this::databaseUpdate, () -> databaseGet(uuid).get());
if (dynamoWriteEnabled()) {
runSafelyAndRecordMetrics(() -> {
try {
dynamoUpdate(account);
} catch (final ConditionalCheckFailedException e) {
dynamoCreate(account);
final Optional<Account> dynamoAccount = dynamoGet(uuid);
if (dynamoAccount.isPresent()) {
updater.accept(dynamoAccount.get());
Account dynamoUpdatedAccount = updateWithRetries(dynamoAccount.get(),
updater,
this::dynamoUpdate,
() -> dynamoGet(uuid).get());
return Optional.of(dynamoUpdatedAccount);
}
return true;
}, Optional.of(account.getUuid()), true,
(databaseSuccess, dynamoSuccess) -> Optional.empty(), // both values are always true
return Optional.empty();
}, Optional.of(uuid), Optional.of(updatedAccount),
this::compareAccounts,
"update");
}
// set the cache again, so that all updates are coalesced
redisSet(updatedAccount);
}
return updatedAccount;
}
private Account updateWithRetries(Account account, Consumer<Account> updater, Consumer<Account> persister, Supplier<Account> retriever) {
final int maxTries = 10;
int tries = 0;
while (tries < maxTries) {
try {
persister.accept(account);
final Account updatedAccount;
try {
updatedAccount = mapper.readValue(mapper.writeValueAsBytes(account), Account.class);
updatedAccount.setUuid(account.getUuid());
} catch (final IOException e) {
// this should really, truly, never happen
throw new IllegalArgumentException(e);
}
account.markStale();
return updatedAccount;
} catch (final ContestedOptimisticLockException e) {
tries++;
account = retriever.get();
updater.accept(account);
}
}
throw new OptimisticLockRetryLimitExceededException();
}
public Account updateDevice(Account account, long deviceId, Consumer<Device> deviceUpdater) {
return update(account, a -> a.getDevice(deviceId).ifPresent(deviceUpdater));
}
public Optional<Account> get(AmbiguousIdentifier identifier) {
@@ -445,6 +506,10 @@ public class AccountsManager {
return Optional.of("number");
}
if (databaseAccount.getVersion() != dynamoAccount.getVersion()) {
return Optional.of("version");
}
if (!Objects.equals(databaseAccount.getIdentityKey(), dynamoAccount.getIdentityKey())) {
return Optional.of("identityKey");
}
@@ -566,13 +631,6 @@ public class AccountsManager {
.collect(Collectors.joining(" -> "));
}
private static abstract class AccountComparisonMixin extends Account {
@JsonIgnore
private int dynamoDbMigrationVersion;
}
private static abstract class DeviceComparisonMixin extends Device {
@JsonIgnore

View File

@@ -0,0 +1,13 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
public class ContestedOptimisticLockException extends RuntimeException {
public ContestedOptimisticLockException() {
super(null, null, true, false);
}
}

View File

@@ -0,0 +1,10 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
public class OptimisticLockRetryLimitExceededException extends RuntimeException {
}

View File

@@ -5,20 +5,20 @@
package org.whispersystems.textsecuregcm.storage;
import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Util;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import static com.codahale.metrics.MetricRegistry.name;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Util;
public class PushFeedbackProcessor extends AccountDatabaseCrawlerListener {
@@ -47,36 +47,42 @@ public class PushFeedbackProcessor extends AccountDatabaseCrawlerListener {
for (Account account : chunkAccounts) {
boolean update = false;
for (Device device : account.getDevices()) {
if (device.getUninstalledFeedbackTimestamp() != 0 &&
device.getUninstalledFeedbackTimestamp() + TimeUnit.DAYS.toMillis(2) <= Util.todayInMillis())
{
if (device.getLastSeen() + TimeUnit.DAYS.toMillis(2) <= Util.todayInMillis()) {
if (!Util.isEmpty(device.getApnId())) {
if (device.getId() == 1) {
device.setUserAgent("OWI");
} else {
device.setUserAgent("OWP");
}
} else if (!Util.isEmpty(device.getGcmId())) {
device.setUserAgent("OWA");
}
device.setGcmId(null);
device.setApnId(null);
device.setVoipApnId(null);
device.setFetchesMessages(false);
final Set<Device> devices = account.getDevices();
for (Device device : devices) {
if (deviceNeedsUpdate(device)) {
if (deviceExpired(device)) {
expired.mark();
} else {
device.setUninstalledFeedbackTimestamp(0);
recovered.mark();
}
update = true;
}
}
if (update) {
accountsManager.update(account);
account = accountsManager.update(account, a -> {
for (Device device: a.getDevices()) {
if (deviceNeedsUpdate(device)) {
if (deviceExpired(device)) {
if (!Util.isEmpty(device.getApnId())) {
if (device.getId() == 1) {
device.setUserAgent("OWI");
} else {
device.setUserAgent("OWP");
}
} else if (!Util.isEmpty(device.getGcmId())) {
device.setUserAgent("OWA");
}
device.setGcmId(null);
device.setApnId(null);
device.setVoipApnId(null);
device.setFetchesMessages(false);
} else {
device.setUninstalledFeedbackTimestamp(0);
}
}
}
});
directoryUpdateAccounts.add(account);
}
}
@@ -85,4 +91,13 @@ public class PushFeedbackProcessor extends AccountDatabaseCrawlerListener {
directoryQueue.refreshRegisteredUsers(directoryUpdateAccounts);
}
}
private boolean deviceNeedsUpdate(final Device device) {
return device.getUninstalledFeedbackTimestamp() != 0 &&
device.getUninstalledFeedbackTimestamp() + TimeUnit.DAYS.toMillis(2) <= Util.todayInMillis();
}
private boolean deviceExpired(final Device device) {
return device.getLastSeen() + TimeUnit.DAYS.toMillis(2) <= Util.todayInMillis();
}
}

View File

@@ -27,6 +27,7 @@ public class AccountRowMapper implements RowMapper<Account> {
Account account = mapper.readValue(resultSet.getString(Accounts.DATA), Account.class);
account.setNumber(resultSet.getString(Accounts.NUMBER));
account.setUuid(UUID.fromString(resultSet.getString(Accounts.UID)));
account.setVersion(resultSet.getInt(Accounts.VERSION));
return account;
} catch (IOException e) {
throw new SQLException(e);