Match KeyController's rate limit key in KeysGrpcService

This commit is contained in:
ravi-signal
2026-03-24 13:26:25 -05:00
committed by GitHub
parent 4d24c814cc
commit fb84066f09
5 changed files with 108 additions and 36 deletions

View File

@@ -398,34 +398,38 @@ public class KeysController {
}
private List<Device> parseDeviceId(String deviceId, Account account) {
return parseDeviceId(deviceId)
.map(id -> account.getDevice(id).map(List::of).orElse(List.of()))
.orElseGet(account::getDevices);
}
private Optional<Byte> parseDeviceId(final String deviceId) {
if (deviceId.equals("*")) {
return account.getDevices();
return Optional.empty();
}
try {
byte id = Byte.parseByte(deviceId);
return account.getDevice(id).map(List::of).orElse(List.of());
return Optional.of(Byte.parseByte(deviceId));
} catch (NumberFormatException e) {
throw new WebApplicationException(Response.status(422).build());
}
}
private String getPreKeysLimiterKey(
final Account account,
final AuthenticatedDevice authenticatedDevice,
final ServiceIdentifier targetIdentifier,
final Account targetAccount,
final String targetDeviceId) {
final String targetRegistrationId = targetDeviceId.equals("*")
? "*"
: String.valueOf(
parseDeviceId(targetDeviceId, targetAccount).getFirst().getRegistrationId(targetIdentifier.identityType()));
return String.format("%s.%s__%s.%s.%s",
final Optional<Byte> parsedTargetDeviceId = parseDeviceId(targetDeviceId);
final Optional<Integer> targetRegistrationId = parsedTargetDeviceId
.flatMap(targetAccount::getDevice)
.map(device -> device.getRegistrationId(targetIdentifier.identityType()));
return RateLimitKeys.preKeyLimiterKey(
account.getUuid(),
authenticatedDevice.deviceId(),
targetIdentifier.uuid(),
targetDeviceId,
targetRegistrationId
);
targetIdentifier,
parsedTargetDeviceId, targetRegistrationId);
}
}

View File

@@ -0,0 +1,35 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import java.util.Optional;
import java.util.UUID;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
/// Helper methods to centrally define rate-limit key formats used in multiple places
public class RateLimitKeys {
/// Key for rate-limiting a device's pre-keys
///
/// @param sourceAci the account identifier of the authenticated device fetching the pre-keys
/// @param sourceDeviceId the deviceId of the authenticated device fetching pre-keys
/// @param targetIdentifier the [ServiceIdentifier] of the target account
/// @param targetDeviceId the deviceId of the target device, empty if fetching all device pre-keys
/// @param targetRegistrationId the registrationId of the target device, empty if fetching all device pre-keys
/// @return the rate-limit key
public static String preKeyLimiterKey(
final UUID sourceAci,
final byte sourceDeviceId,
final ServiceIdentifier targetIdentifier,
final Optional<Byte> targetDeviceId,
final Optional<Integer> targetRegistrationId) {
return String.format("%s.%s__%s.%s.%s",
sourceAci,
sourceDeviceId,
targetIdentifier.uuid(),
targetDeviceId.map(String::valueOf).orElse("*"),
targetRegistrationId.map(String::valueOf).orElse("*"));
}
}

View File

@@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.grpc;
import io.grpc.StatusRuntimeException;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiFunction;
@@ -31,6 +32,7 @@ import org.signal.libsignal.protocol.kem.KEMPublicKey;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticationUtil;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.controllers.RateLimitKeys;
import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
@@ -96,18 +98,28 @@ public class KeysGrpcService extends SimpleKeysGrpc.KeysImplBase {
final ServiceIdentifier targetIdentifier =
ServiceIdentifierUtil.fromGrpcServiceIdentifier(request.getTargetIdentifier());
final Optional<Account> maybeTargetAccount = accountsManager.getByServiceIdentifier(targetIdentifier);
final byte deviceId = request.hasDeviceId()
? DeviceIdUtil.validate(request.getDeviceId())
: KeysGrpcHelper.ALL_DEVICES;
final String rateLimitKey = authenticatedDevice.accountIdentifier() + "." +
authenticatedDevice.deviceId() + "__" +
targetIdentifier.uuid() + "." +
deviceId;
final Optional<Integer> targetRegistrationId = maybeTargetAccount
.filter(_ -> request.hasDeviceId())
.flatMap(targetAccount -> targetAccount.getDevice(deviceId))
.map(device -> device.getRegistrationId(targetIdentifier.identityType()));
final String rateLimitKey = RateLimitKeys.preKeyLimiterKey(
authenticatedDevice.accountIdentifier(),
authenticatedDevice.deviceId(),
targetIdentifier,
Optional.ofNullable(request.hasDeviceId() ? deviceId : null),
targetRegistrationId);
rateLimiters.getPreKeysLimiter().validate(rateLimitKey);
return accountsManager.getByServiceIdentifier(targetIdentifier)
return maybeTargetAccount
.flatMap(targetAccount -> KeysGrpcHelper.getPreKeys(targetAccount, targetIdentifier, deviceId, keysManager))
.map(accountPreKeyBundles -> GetPreKeysResponse.newBuilder()
.setPreKeys(accountPreKeyBundles)

View File

@@ -14,6 +14,7 @@ import "google/protobuf/empty.proto";
import "org/signal/chat/common.proto";
import "org/signal/chat/errors.proto";
import "org/signal/chat/require.proto";
import "org/signal/chat/tag.proto";
// Provides methods for working with pre-keys.
service Keys {
@@ -152,7 +153,7 @@ message GetPreKeysResponse {
// Either the target account was not found, no active device with the given
// ID (if specified) was found on the target account.
errors.NotFound target_not_found = 2;
errors.NotFound target_not_found = 2 [(tag.reason) = "not_found"];
}
}
@@ -163,10 +164,10 @@ message GetPreKeysAnonymousResponse {
// Either the target account was not found, no active device with the given
// ID (if specified) was found on the target account.
errors.NotFound target_not_found = 2;
errors.NotFound target_not_found = 2 [(tag.reason) = "not_found"];
// The provided unidentified authorization credential was invalid
errors.FailedUnidentifiedAuthorization failed_unidentified_authorization = 3;
errors.FailedUnidentifiedAuthorization failed_unidentified_authorization = 3 [(tag.reason) = "failed_unidentified_authorization"];
}
}

View File

@@ -15,7 +15,6 @@ import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
@@ -38,6 +37,7 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.mockito.Mock;
import org.signal.chat.common.EcPreKey;
import org.signal.chat.common.EcSignedPreKey;
@@ -599,28 +599,48 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
assertTrue(response.hasTargetNotFound());
}
@Test
void getPreKeysRateLimited() throws RateLimitExceededException {
final Account targetAccount = mock(Account.class);
when(targetAccount.getUuid()).thenReturn(UUID.randomUUID());
when(targetAccount.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(ECKeyPair.generate().getPublicKey()));
when(targetAccount.getDevices()).thenReturn(Collections.emptyList());
when(targetAccount.getDevice(anyByte())).thenReturn(Optional.empty());
@CartesianTest
void getPreKeysRateLimited(
@CartesianTest.Values(booleans = {true, false}) boolean allDevices,
@CartesianTest.Values(booleans = {true, false}) boolean targetMissing) throws RateLimitExceededException {
final UUID targetAccountId = UUID.randomUUID();
final byte targetDeviceId = 1;
final int registrationId = 123;
if (!targetMissing) {
final Device mockDevice = mock(Device.class);
when(mockDevice.getRegistrationId(IdentityType.ACI)).thenReturn(registrationId);
final Account targetAccount = mock(Account.class);
when(targetAccount.getUuid()).thenReturn(targetAccountId);
when(targetAccount.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(ECKeyPair.generate().getPublicKey()));
when(targetAccount.getDevice(targetDeviceId)).thenReturn(Optional.of(mockDevice));
when(targetAccount.getDevices()).thenReturn(List.of(mockDevice));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(targetAccountId)))
.thenReturn(Optional.of(targetAccount));
}
when(accountsManager.getByServiceIdentifier(any()))
.thenReturn(Optional.of(targetAccount));
final Duration retryAfterDuration = Duration.ofMinutes(7);
final String expectedRateLimitKey = AUTHENTICATED_ACI + "." +
AUTHENTICATED_DEVICE_ID + "__" +
targetAccountId + "." +
(allDevices ? "*" : targetDeviceId) + "." +
((allDevices || targetMissing) ? "*" : registrationId);
doThrow(new RateLimitExceededException(retryAfterDuration))
.when(preKeysRateLimiter).validate(anyString());
assertRateLimitExceeded(retryAfterDuration, () -> authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder()
final GetPreKeysRequest.Builder builder = GetPreKeysRequest.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder()
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
.setUuid(UUIDUtil.toByteString(UUID.randomUUID()))
.build())
.build()));
verifyNoInteractions(accountsManager);
.setUuid(UUIDUtil.toByteString(targetAccountId)));
if (!allDevices) {
builder.setDeviceId(targetDeviceId);
}
assertRateLimitExceeded(retryAfterDuration, () -> authenticatedServiceStub().getPreKeys(builder.build()));
verify(preKeysRateLimiter).validate(expectedRateLimitKey);
}
}