mirror of
https://github.com/signalapp/Signal-Server
synced 2026-04-27 23:23:13 +01:00
Match KeyController's rate limit key in KeysGrpcService
This commit is contained in:
@@ -398,34 +398,38 @@ public class KeysController {
|
||||
}
|
||||
|
||||
private List<Device> parseDeviceId(String deviceId, Account account) {
|
||||
return parseDeviceId(deviceId)
|
||||
.map(id -> account.getDevice(id).map(List::of).orElse(List.of()))
|
||||
.orElseGet(account::getDevices);
|
||||
}
|
||||
|
||||
private Optional<Byte> parseDeviceId(final String deviceId) {
|
||||
if (deviceId.equals("*")) {
|
||||
return account.getDevices();
|
||||
return Optional.empty();
|
||||
}
|
||||
try {
|
||||
byte id = Byte.parseByte(deviceId);
|
||||
return account.getDevice(id).map(List::of).orElse(List.of());
|
||||
return Optional.of(Byte.parseByte(deviceId));
|
||||
} catch (NumberFormatException e) {
|
||||
throw new WebApplicationException(Response.status(422).build());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private String getPreKeysLimiterKey(
|
||||
final Account account,
|
||||
final AuthenticatedDevice authenticatedDevice,
|
||||
final ServiceIdentifier targetIdentifier,
|
||||
final Account targetAccount,
|
||||
final String targetDeviceId) {
|
||||
final String targetRegistrationId = targetDeviceId.equals("*")
|
||||
? "*"
|
||||
: String.valueOf(
|
||||
parseDeviceId(targetDeviceId, targetAccount).getFirst().getRegistrationId(targetIdentifier.identityType()));
|
||||
|
||||
return String.format("%s.%s__%s.%s.%s",
|
||||
final Optional<Byte> parsedTargetDeviceId = parseDeviceId(targetDeviceId);
|
||||
final Optional<Integer> targetRegistrationId = parsedTargetDeviceId
|
||||
.flatMap(targetAccount::getDevice)
|
||||
.map(device -> device.getRegistrationId(targetIdentifier.identityType()));
|
||||
return RateLimitKeys.preKeyLimiterKey(
|
||||
account.getUuid(),
|
||||
authenticatedDevice.deviceId(),
|
||||
targetIdentifier.uuid(),
|
||||
targetDeviceId,
|
||||
targetRegistrationId
|
||||
);
|
||||
targetIdentifier,
|
||||
parsedTargetDeviceId, targetRegistrationId);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
/*
|
||||
* Copyright 2026 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.controllers;
|
||||
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
|
||||
|
||||
/// Helper methods to centrally define rate-limit key formats used in multiple places
|
||||
public class RateLimitKeys {
|
||||
|
||||
/// Key for rate-limiting a device's pre-keys
|
||||
///
|
||||
/// @param sourceAci the account identifier of the authenticated device fetching the pre-keys
|
||||
/// @param sourceDeviceId the deviceId of the authenticated device fetching pre-keys
|
||||
/// @param targetIdentifier the [ServiceIdentifier] of the target account
|
||||
/// @param targetDeviceId the deviceId of the target device, empty if fetching all device pre-keys
|
||||
/// @param targetRegistrationId the registrationId of the target device, empty if fetching all device pre-keys
|
||||
/// @return the rate-limit key
|
||||
public static String preKeyLimiterKey(
|
||||
final UUID sourceAci,
|
||||
final byte sourceDeviceId,
|
||||
final ServiceIdentifier targetIdentifier,
|
||||
final Optional<Byte> targetDeviceId,
|
||||
final Optional<Integer> targetRegistrationId) {
|
||||
return String.format("%s.%s__%s.%s.%s",
|
||||
sourceAci,
|
||||
sourceDeviceId,
|
||||
targetIdentifier.uuid(),
|
||||
targetDeviceId.map(String::valueOf).orElse("*"),
|
||||
targetRegistrationId.map(String::valueOf).orElse("*"));
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import io.grpc.StatusRuntimeException;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.function.BiFunction;
|
||||
@@ -31,6 +32,7 @@ import org.signal.libsignal.protocol.kem.KEMPublicKey;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticationUtil;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitKeys;
|
||||
import org.whispersystems.textsecuregcm.entities.ECPreKey;
|
||||
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
|
||||
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
|
||||
@@ -96,18 +98,28 @@ public class KeysGrpcService extends SimpleKeysGrpc.KeysImplBase {
|
||||
final ServiceIdentifier targetIdentifier =
|
||||
ServiceIdentifierUtil.fromGrpcServiceIdentifier(request.getTargetIdentifier());
|
||||
|
||||
|
||||
final Optional<Account> maybeTargetAccount = accountsManager.getByServiceIdentifier(targetIdentifier);
|
||||
|
||||
final byte deviceId = request.hasDeviceId()
|
||||
? DeviceIdUtil.validate(request.getDeviceId())
|
||||
: KeysGrpcHelper.ALL_DEVICES;
|
||||
|
||||
final String rateLimitKey = authenticatedDevice.accountIdentifier() + "." +
|
||||
authenticatedDevice.deviceId() + "__" +
|
||||
targetIdentifier.uuid() + "." +
|
||||
deviceId;
|
||||
final Optional<Integer> targetRegistrationId = maybeTargetAccount
|
||||
.filter(_ -> request.hasDeviceId())
|
||||
.flatMap(targetAccount -> targetAccount.getDevice(deviceId))
|
||||
.map(device -> device.getRegistrationId(targetIdentifier.identityType()));
|
||||
|
||||
final String rateLimitKey = RateLimitKeys.preKeyLimiterKey(
|
||||
authenticatedDevice.accountIdentifier(),
|
||||
authenticatedDevice.deviceId(),
|
||||
targetIdentifier,
|
||||
Optional.ofNullable(request.hasDeviceId() ? deviceId : null),
|
||||
targetRegistrationId);
|
||||
|
||||
rateLimiters.getPreKeysLimiter().validate(rateLimitKey);
|
||||
|
||||
return accountsManager.getByServiceIdentifier(targetIdentifier)
|
||||
return maybeTargetAccount
|
||||
.flatMap(targetAccount -> KeysGrpcHelper.getPreKeys(targetAccount, targetIdentifier, deviceId, keysManager))
|
||||
.map(accountPreKeyBundles -> GetPreKeysResponse.newBuilder()
|
||||
.setPreKeys(accountPreKeyBundles)
|
||||
|
||||
@@ -14,6 +14,7 @@ import "google/protobuf/empty.proto";
|
||||
import "org/signal/chat/common.proto";
|
||||
import "org/signal/chat/errors.proto";
|
||||
import "org/signal/chat/require.proto";
|
||||
import "org/signal/chat/tag.proto";
|
||||
|
||||
// Provides methods for working with pre-keys.
|
||||
service Keys {
|
||||
@@ -152,7 +153,7 @@ message GetPreKeysResponse {
|
||||
|
||||
// Either the target account was not found, no active device with the given
|
||||
// ID (if specified) was found on the target account.
|
||||
errors.NotFound target_not_found = 2;
|
||||
errors.NotFound target_not_found = 2 [(tag.reason) = "not_found"];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -163,10 +164,10 @@ message GetPreKeysAnonymousResponse {
|
||||
|
||||
// Either the target account was not found, no active device with the given
|
||||
// ID (if specified) was found on the target account.
|
||||
errors.NotFound target_not_found = 2;
|
||||
errors.NotFound target_not_found = 2 [(tag.reason) = "not_found"];
|
||||
|
||||
// The provided unidentified authorization credential was invalid
|
||||
errors.FailedUnidentifiedAuthorization failed_unidentified_authorization = 3;
|
||||
errors.FailedUnidentifiedAuthorization failed_unidentified_authorization = 3 [(tag.reason) = "failed_unidentified_authorization"];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded;
|
||||
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
|
||||
@@ -38,6 +37,7 @@ import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.EnumSource;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.junitpioneer.jupiter.cartesian.CartesianTest;
|
||||
import org.mockito.Mock;
|
||||
import org.signal.chat.common.EcPreKey;
|
||||
import org.signal.chat.common.EcSignedPreKey;
|
||||
@@ -599,28 +599,48 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
|
||||
assertTrue(response.hasTargetNotFound());
|
||||
}
|
||||
|
||||
@Test
|
||||
void getPreKeysRateLimited() throws RateLimitExceededException {
|
||||
final Account targetAccount = mock(Account.class);
|
||||
when(targetAccount.getUuid()).thenReturn(UUID.randomUUID());
|
||||
when(targetAccount.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(ECKeyPair.generate().getPublicKey()));
|
||||
when(targetAccount.getDevices()).thenReturn(Collections.emptyList());
|
||||
when(targetAccount.getDevice(anyByte())).thenReturn(Optional.empty());
|
||||
@CartesianTest
|
||||
void getPreKeysRateLimited(
|
||||
@CartesianTest.Values(booleans = {true, false}) boolean allDevices,
|
||||
@CartesianTest.Values(booleans = {true, false}) boolean targetMissing) throws RateLimitExceededException {
|
||||
final UUID targetAccountId = UUID.randomUUID();
|
||||
final byte targetDeviceId = 1;
|
||||
final int registrationId = 123;
|
||||
|
||||
if (!targetMissing) {
|
||||
final Device mockDevice = mock(Device.class);
|
||||
when(mockDevice.getRegistrationId(IdentityType.ACI)).thenReturn(registrationId);
|
||||
|
||||
final Account targetAccount = mock(Account.class);
|
||||
when(targetAccount.getUuid()).thenReturn(targetAccountId);
|
||||
when(targetAccount.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(ECKeyPair.generate().getPublicKey()));
|
||||
when(targetAccount.getDevice(targetDeviceId)).thenReturn(Optional.of(mockDevice));
|
||||
when(targetAccount.getDevices()).thenReturn(List.of(mockDevice));
|
||||
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(targetAccountId)))
|
||||
.thenReturn(Optional.of(targetAccount));
|
||||
}
|
||||
|
||||
when(accountsManager.getByServiceIdentifier(any()))
|
||||
.thenReturn(Optional.of(targetAccount));
|
||||
|
||||
final Duration retryAfterDuration = Duration.ofMinutes(7);
|
||||
|
||||
final String expectedRateLimitKey = AUTHENTICATED_ACI + "." +
|
||||
AUTHENTICATED_DEVICE_ID + "__" +
|
||||
targetAccountId + "." +
|
||||
(allDevices ? "*" : targetDeviceId) + "." +
|
||||
((allDevices || targetMissing) ? "*" : registrationId);
|
||||
|
||||
|
||||
doThrow(new RateLimitExceededException(retryAfterDuration))
|
||||
.when(preKeysRateLimiter).validate(anyString());
|
||||
|
||||
assertRateLimitExceeded(retryAfterDuration, () -> authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder()
|
||||
final GetPreKeysRequest.Builder builder = GetPreKeysRequest.newBuilder()
|
||||
.setTargetIdentifier(ServiceIdentifier.newBuilder()
|
||||
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
|
||||
.setUuid(UUIDUtil.toByteString(UUID.randomUUID()))
|
||||
.build())
|
||||
.build()));
|
||||
verifyNoInteractions(accountsManager);
|
||||
.setUuid(UUIDUtil.toByteString(targetAccountId)));
|
||||
if (!allDevices) {
|
||||
builder.setDeviceId(targetDeviceId);
|
||||
}
|
||||
assertRateLimitExceeded(retryAfterDuration, () -> authenticatedServiceStub().getPreKeys(builder.build()));
|
||||
verify(preKeysRateLimiter).validate(expectedRateLimitKey);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user