diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LeakyBucketRateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/LeakyBucketRateLimiter.java index d3d47f128..a4583f1b3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LeakyBucketRateLimiter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/LeakyBucketRateLimiter.java @@ -68,7 +68,7 @@ public class LeakyBucketRateLimiter implements RateLimiter { } @Override - public void validate(final String key, final int amount) throws RateLimitExceededException { + public void validate(final String key, final long amount) throws RateLimitExceededException { final RateLimiterConfig config = config(); try { final long deficitPermitsAmount = executeValidateScript(config, key, amount, true); @@ -90,7 +90,7 @@ public class LeakyBucketRateLimiter implements RateLimiter { } @Override - public CompletionStage validateAsync(final String key, final int amount) { + public CompletionStage validateAsync(final String key, final long amount) { final RateLimiterConfig config = config(); return executeValidateScriptAsync(config, key, amount, true) @@ -117,7 +117,7 @@ public class LeakyBucketRateLimiter implements RateLimiter { } @Override - public boolean hasAvailablePermits(final String key, final int permits) { + public boolean hasAvailablePermits(final String key, final long permits) { final RateLimiterConfig config = config(); try { final long deficitPermitsAmount = executeValidateScript(config, key, permits, false); @@ -132,7 +132,7 @@ public class LeakyBucketRateLimiter implements RateLimiter { } @Override - public CompletionStage hasAvailablePermitsAsync(final String key, final int amount) { + public CompletionStage hasAvailablePermitsAsync(final String key, final long amount) { final RateLimiterConfig config = config(); return executeValidateScriptAsync(config, key, amount, false) .thenApply(deficitPermitsAmount -> deficitPermitsAmount == 0) @@ -162,7 +162,7 @@ public class LeakyBucketRateLimiter implements RateLimiter { return configResolver.get(); } - private long executeValidateScript(final RateLimiterConfig config, final String key, final int amount, final boolean applyChanges) { + private long executeValidateScript(final RateLimiterConfig config, final String key, final long amount, final boolean applyChanges) { final List keys = List.of(bucketName(name, key)); final List arguments = List.of( String.valueOf(config.bucketSize()), @@ -174,7 +174,7 @@ public class LeakyBucketRateLimiter implements RateLimiter { return (Long) validateScript.execute(keys, arguments); } - private CompletionStage executeValidateScriptAsync(final RateLimiterConfig config, final String key, final int amount, final boolean applyChanges) { + private CompletionStage executeValidateScriptAsync(final RateLimiterConfig config, final String key, final long amount, final boolean applyChanges) { final List keys = List.of(bucketName(name, key)); final List arguments = List.of( String.valueOf(config.bucketSize()), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java index a9e1f265d..ca7a997d4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java @@ -13,13 +13,13 @@ import reactor.core.publisher.Mono; public interface RateLimiter { - void validate(String key, int amount) throws RateLimitExceededException; + void validate(String key, long amount) throws RateLimitExceededException; - CompletionStage validateAsync(String key, int amount); + CompletionStage validateAsync(String key, long amount); - boolean hasAvailablePermits(String key, int permits); + boolean hasAvailablePermits(String key, long permits); - CompletionStage hasAvailablePermitsAsync(String key, int amount); + CompletionStage hasAvailablePermitsAsync(String key, long amount); void clear(String key); @@ -35,7 +35,7 @@ public interface RateLimiter { validate(accountUuid.toString()); } - default void validate(final UUID accountUuid, final int permits) throws RateLimitExceededException { + default void validate(final UUID accountUuid, final long permits) throws RateLimitExceededException { validate(accountUuid.toString(), permits); } @@ -63,11 +63,11 @@ public interface RateLimiter { return validateReactive(accountUuid.toString()); } - default boolean hasAvailablePermits(final UUID accountUuid, final int permits) { + default boolean hasAvailablePermits(final UUID accountUuid, final long permits) { return hasAvailablePermits(accountUuid.toString(), permits); } - default CompletionStage hasAvailablePermitsAsync(final UUID accountUuid, final int permits) { + default CompletionStage hasAvailablePermitsAsync(final UUID accountUuid, final long permits) { return hasAvailablePermitsAsync(accountUuid.toString(), permits); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiterConfig.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiterConfig.java index 53f9f45b3..ec6eee417 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiterConfig.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiterConfig.java @@ -9,7 +9,7 @@ import io.swagger.v3.oas.annotations.media.Schema; import jakarta.validation.constraints.AssertTrue; import java.time.Duration; -public record RateLimiterConfig(int bucketSize, Duration permitRegenerationDuration, boolean failOpen) { +public record RateLimiterConfig(long bucketSize, Duration permitRegenerationDuration, boolean failOpen) { public double leakRatePerMillis() { return 1.0 / (permitRegenerationDuration.toNanos() / 1e6); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthManagerTest.java index ba621ad0d..655d1591a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthManagerTest.java @@ -9,9 +9,8 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatException; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatNoException; -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -29,7 +28,6 @@ import java.time.temporal.ChronoUnit; import java.util.List; import java.util.Optional; import java.util.UUID; -import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; import java.util.function.Consumer; import org.assertj.core.api.Assertions; @@ -618,12 +616,12 @@ public class BackupAuthManagerTest { final RateLimiters limiters = mock(RateLimiters.class); final RateLimiter allowLimiter = mock(RateLimiter.class); - when(allowLimiter.hasAvailablePermitsAsync(eq(aci), anyInt())).thenReturn(CompletableFuture.completedFuture(true)); + when(allowLimiter.hasAvailablePermitsAsync(eq(aci), anyLong())).thenReturn(CompletableFuture.completedFuture(true)); when(allowLimiter.validateAsync(aci)).thenReturn(CompletableFuture.completedFuture(null)); when(allowLimiter.config()).thenReturn(new RateLimiterConfig(1, Duration.ofDays(1), false)); final RateLimiter denyLimiter = mock(RateLimiter.class); - when(denyLimiter.hasAvailablePermitsAsync(eq(aci), anyInt())).thenReturn(CompletableFuture.completedFuture(false)); + when(denyLimiter.hasAvailablePermitsAsync(eq(aci), anyLong())).thenReturn(CompletableFuture.completedFuture(false)); when(denyLimiter.validateAsync(aci)) .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null))); when(denyLimiter.config()).thenReturn(new RateLimiterConfig(1, Duration.ofDays(1), false)); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcServiceTest.java index 63cd263d0..8cbaadf8c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcServiceTest.java @@ -9,7 +9,6 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyCollection; -import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; @@ -328,7 +327,7 @@ class MessagesAnonymousGrpcServiceTest extends final Duration retryDuration = Duration.ofHours(7); - doThrow(new RateLimitExceededException(retryDuration)).when(rateLimiter).validate(eq(serviceIdentifier.uuid()), anyInt()); + doThrow(new RateLimitExceededException(retryDuration)).when(rateLimiter).validate(eq(serviceIdentifier.uuid()), anyLong()); final Map messages = Map.of(deviceId, IndividualRecipientMessageBundle.Message.newBuilder() diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcServiceTest.java index c5f840dde..d86082d44 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcServiceTest.java @@ -8,7 +8,7 @@ package org.whispersystems.textsecuregcm.grpc; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyByte; -import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; @@ -293,7 +293,7 @@ class MessagesGrpcServiceTest extends SimpleBaseGrpcTest messages = Map.of(deviceId, IndividualRecipientMessageBundle.Message.newBuilder() @@ -555,7 +555,7 @@ class MessagesGrpcServiceTest extends SimpleBaseGrpcTest messages = Map.of(AUTHENTICATED_DEVICE_ID, IndividualRecipientMessageBundle.Message.newBuilder() diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitChallengeOptionManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitChallengeOptionManagerTest.java index 861364e31..e6e4e46d1 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitChallengeOptionManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitChallengeOptionManagerTest.java @@ -8,7 +8,7 @@ package org.whispersystems.textsecuregcm.limits; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -52,12 +52,12 @@ class RateLimitChallengeOptionManagerTest { when(rateLimiters.getPushChallengeAttemptLimiter()).thenReturn(pushChallengeAttemptLimiter); when(rateLimiters.getPushChallengeSuccessLimiter()).thenReturn(pushChallengeSuccessLimiter); - when(captchaChallengeAttemptLimiter.hasAvailablePermits(any(UUID.class), anyInt())).thenReturn( + when(captchaChallengeAttemptLimiter.hasAvailablePermits(any(UUID.class), anyLong())).thenReturn( captchaAttemptPermitted); - when(captchaChallengeSuccessLimiter.hasAvailablePermits(any(UUID.class), anyInt())).thenReturn( + when(captchaChallengeSuccessLimiter.hasAvailablePermits(any(UUID.class), anyLong())).thenReturn( captchaSuccessPermitted); - when(pushChallengeAttemptLimiter.hasAvailablePermits(any(UUID.class), anyInt())).thenReturn(pushAttemptPermitted); - when(pushChallengeSuccessLimiter.hasAvailablePermits(any(UUID.class), anyInt())).thenReturn(pushSuccessPermitted); + when(pushChallengeAttemptLimiter.hasAvailablePermits(any(UUID.class), anyLong())).thenReturn(pushAttemptPermitted); + when(pushChallengeSuccessLimiter.hasAvailablePermits(any(UUID.class), anyLong())).thenReturn(pushSuccessPermitted); final int expectedLength = (expectCaptcha ? 1 : 0) + (expectPushChallenge ? 1 : 0);