diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index 00ebc32e9..57041a684 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -398,34 +398,38 @@ public class KeysController { } private List parseDeviceId(String deviceId, Account account) { + return parseDeviceId(deviceId) + .map(id -> account.getDevice(id).map(List::of).orElse(List.of())) + .orElseGet(account::getDevices); + } + + private Optional parseDeviceId(final String deviceId) { if (deviceId.equals("*")) { - return account.getDevices(); + return Optional.empty(); } try { - byte id = Byte.parseByte(deviceId); - return account.getDevice(id).map(List::of).orElse(List.of()); + return Optional.of(Byte.parseByte(deviceId)); } catch (NumberFormatException e) { throw new WebApplicationException(Response.status(422).build()); } } + private String getPreKeysLimiterKey( final Account account, final AuthenticatedDevice authenticatedDevice, final ServiceIdentifier targetIdentifier, final Account targetAccount, final String targetDeviceId) { - final String targetRegistrationId = targetDeviceId.equals("*") - ? "*" - : String.valueOf( - parseDeviceId(targetDeviceId, targetAccount).getFirst().getRegistrationId(targetIdentifier.identityType())); - - return String.format("%s.%s__%s.%s.%s", + final Optional parsedTargetDeviceId = parseDeviceId(targetDeviceId); + final Optional targetRegistrationId = parsedTargetDeviceId + .flatMap(targetAccount::getDevice) + .map(device -> device.getRegistrationId(targetIdentifier.identityType())); + return RateLimitKeys.preKeyLimiterKey( account.getUuid(), authenticatedDevice.deviceId(), - targetIdentifier.uuid(), - targetDeviceId, - targetRegistrationId - ); + targetIdentifier, + parsedTargetDeviceId, targetRegistrationId); } + } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RateLimitKeys.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RateLimitKeys.java new file mode 100644 index 000000000..4684ab37a --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RateLimitKeys.java @@ -0,0 +1,35 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.controllers; + +import java.util.Optional; +import java.util.UUID; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; + +/// Helper methods to centrally define rate-limit key formats used in multiple places +public class RateLimitKeys { + + /// Key for rate-limiting a device's pre-keys + /// + /// @param sourceAci the account identifier of the authenticated device fetching the pre-keys + /// @param sourceDeviceId the deviceId of the authenticated device fetching pre-keys + /// @param targetIdentifier the [ServiceIdentifier] of the target account + /// @param targetDeviceId the deviceId of the target device, empty if fetching all device pre-keys + /// @param targetRegistrationId the registrationId of the target device, empty if fetching all device pre-keys + /// @return the rate-limit key + public static String preKeyLimiterKey( + final UUID sourceAci, + final byte sourceDeviceId, + final ServiceIdentifier targetIdentifier, + final Optional targetDeviceId, + final Optional targetRegistrationId) { + return String.format("%s.%s__%s.%s.%s", + sourceAci, + sourceDeviceId, + targetIdentifier.uuid(), + targetDeviceId.map(String::valueOf).orElse("*"), + targetRegistrationId.map(String::valueOf).orElse("*")); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java index 8c4ca7ebc..6f5ad167a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java @@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.grpc; import io.grpc.StatusRuntimeException; import java.util.List; +import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.function.BiFunction; @@ -31,6 +32,7 @@ import org.signal.libsignal.protocol.kem.KEMPublicKey; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticationUtil; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; +import org.whispersystems.textsecuregcm.controllers.RateLimitKeys; import org.whispersystems.textsecuregcm.entities.ECPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; @@ -96,18 +98,28 @@ public class KeysGrpcService extends SimpleKeysGrpc.KeysImplBase { final ServiceIdentifier targetIdentifier = ServiceIdentifierUtil.fromGrpcServiceIdentifier(request.getTargetIdentifier()); + + final Optional maybeTargetAccount = accountsManager.getByServiceIdentifier(targetIdentifier); + final byte deviceId = request.hasDeviceId() ? DeviceIdUtil.validate(request.getDeviceId()) : KeysGrpcHelper.ALL_DEVICES; - final String rateLimitKey = authenticatedDevice.accountIdentifier() + "." + - authenticatedDevice.deviceId() + "__" + - targetIdentifier.uuid() + "." + - deviceId; + final Optional targetRegistrationId = maybeTargetAccount + .filter(_ -> request.hasDeviceId()) + .flatMap(targetAccount -> targetAccount.getDevice(deviceId)) + .map(device -> device.getRegistrationId(targetIdentifier.identityType())); + + final String rateLimitKey = RateLimitKeys.preKeyLimiterKey( + authenticatedDevice.accountIdentifier(), + authenticatedDevice.deviceId(), + targetIdentifier, + Optional.ofNullable(request.hasDeviceId() ? deviceId : null), + targetRegistrationId); rateLimiters.getPreKeysLimiter().validate(rateLimitKey); - return accountsManager.getByServiceIdentifier(targetIdentifier) + return maybeTargetAccount .flatMap(targetAccount -> KeysGrpcHelper.getPreKeys(targetAccount, targetIdentifier, deviceId, keysManager)) .map(accountPreKeyBundles -> GetPreKeysResponse.newBuilder() .setPreKeys(accountPreKeyBundles) diff --git a/service/src/main/proto/org/signal/chat/keys.proto b/service/src/main/proto/org/signal/chat/keys.proto index 62699af21..cbc2b9bd9 100644 --- a/service/src/main/proto/org/signal/chat/keys.proto +++ b/service/src/main/proto/org/signal/chat/keys.proto @@ -14,6 +14,7 @@ import "google/protobuf/empty.proto"; import "org/signal/chat/common.proto"; import "org/signal/chat/errors.proto"; import "org/signal/chat/require.proto"; +import "org/signal/chat/tag.proto"; // Provides methods for working with pre-keys. service Keys { @@ -152,7 +153,7 @@ message GetPreKeysResponse { // Either the target account was not found, no active device with the given // ID (if specified) was found on the target account. - errors.NotFound target_not_found = 2; + errors.NotFound target_not_found = 2 [(tag.reason) = "not_found"]; } } @@ -163,10 +164,10 @@ message GetPreKeysAnonymousResponse { // Either the target account was not found, no active device with the given // ID (if specified) was found on the target account. - errors.NotFound target_not_found = 2; + errors.NotFound target_not_found = 2 [(tag.reason) = "not_found"]; // The provided unidentified authorization credential was invalid - errors.FailedUnidentifiedAuthorization failed_unidentified_authorization = 3; + errors.FailedUnidentifiedAuthorization failed_unidentified_authorization = 3 [(tag.reason) = "failed_unidentified_authorization"]; } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcServiceTest.java index 80d9d7b4c..4f688ac39 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcServiceTest.java @@ -15,7 +15,6 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded; import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException; @@ -38,6 +37,7 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.MethodSource; +import org.junitpioneer.jupiter.cartesian.CartesianTest; import org.mockito.Mock; import org.signal.chat.common.EcPreKey; import org.signal.chat.common.EcSignedPreKey; @@ -599,28 +599,48 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder() + final GetPreKeysRequest.Builder builder = GetPreKeysRequest.newBuilder() .setTargetIdentifier(ServiceIdentifier.newBuilder() .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) - .setUuid(UUIDUtil.toByteString(UUID.randomUUID())) - .build()) - .build())); - verifyNoInteractions(accountsManager); + .setUuid(UUIDUtil.toByteString(targetAccountId))); + if (!allDevices) { + builder.setDeviceId(targetDeviceId); + } + assertRateLimitExceeded(retryAfterDuration, () -> authenticatedServiceStub().getPreKeys(builder.build())); + verify(preKeysRateLimiter).validate(expectedRateLimitKey); } }