Explicitly call spam-filter for challenges

Pass in the same information to the spam-filter, but just use explicit
method calls rather than jersey request filters.
This commit is contained in:
Ravi Khadiwala
2024-02-15 18:21:57 -06:00
committed by ravi-signal
parent 30b5ad1515
commit 4f40c128bf
7 changed files with 109 additions and 147 deletions

View File

@@ -23,15 +23,11 @@ import io.dropwizard.testing.junit5.ResourceExtension;
import java.io.IOException;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import javax.ws.rs.client.Entity;
import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.container.ContainerRequestFilter;
import javax.ws.rs.core.Feature;
import javax.ws.rs.core.FeatureContext;
import javax.ws.rs.core.Response;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
@@ -40,9 +36,8 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.spam.PushChallengeConfigProvider;
import org.whispersystems.textsecuregcm.spam.ScoreThreshold;
import org.whispersystems.textsecuregcm.spam.ScoreThresholdProvider;
import org.whispersystems.textsecuregcm.spam.ChallengeConstraintChecker;
import org.whispersystems.textsecuregcm.spam.ChallengeConstraintChecker.ChallengeConstraints;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider;
@@ -51,26 +46,14 @@ import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider;
class ChallengeControllerTest {
private static final RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class);
private static final ChallengeConstraintChecker challengeConstraintChecker = mock(ChallengeConstraintChecker.class);
private static final ChallengeController challengeController = new ChallengeController(rateLimitChallengeManager);
private static final AtomicReference<Float> scoreThreshold = new AtomicReference<>();
private static final ChallengeController challengeController =
new ChallengeController(rateLimitChallengeManager, challengeConstraintChecker);
private static final ResourceExtension EXTENSION = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedAccount.class))
.addProvider(ScoreThresholdProvider.ScoreThresholdFeature.class)
.addProvider(PushChallengeConfigProvider.PushChallengeConfigFeature.class)
.addProvider(new Feature() {
public boolean configure(FeatureContext featureContext) {
featureContext.register(new ContainerRequestFilter() {
public void filter(ContainerRequestContext requestContext) {
requestContext.setProperty(ScoreThreshold.PROPERTY_NAME, scoreThreshold.get());
}
});
return true;
}
})
.addProvider(new TestRemoteAddressFilterProvider("127.0.0.1"))
.setMapper(SystemMapper.jsonMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
@@ -78,10 +61,15 @@ class ChallengeControllerTest {
.addResource(challengeController)
.build();
@BeforeEach
void setup() {
when(challengeConstraintChecker.challengeConstraints(any(), any()))
.thenReturn(new ChallengeConstraints(true, Optional.empty()));
}
@AfterEach
void teardown() {
reset(rateLimitChallengeManager);
scoreThreshold.set(null);
reset(rateLimitChallengeManager, challengeConstraintChecker);
}
@Test
@@ -140,7 +128,8 @@ class ChallengeControllerTest {
if (hasThreshold) {
scoreThreshold.set(Float.valueOf(0.5f));
when(challengeConstraintChecker.challengeConstraints(any(), any()))
.thenReturn(new ChallengeConstraints(true, Optional.of(0.5F)));
}
final Response response = EXTENSION.target("/v1/challenge")
.request()
@@ -240,6 +229,40 @@ class ChallengeControllerTest {
}
}
@Test
void testRequestPushChallengeNotPermitted() {
when(challengeConstraintChecker.challengeConstraints(any(), any()))
.thenReturn(new ChallengeConstraints(false, Optional.empty()));
final Response response = EXTENSION.target("/v1/challenge/push")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.post(Entity.text(""));
assertEquals(429, response.getStatus());
verifyNoInteractions(rateLimitChallengeManager);
}
@Test
void testAnswerPushChallengeNotPermitted() {
when(challengeConstraintChecker.challengeConstraints(any(), any()))
.thenReturn(new ChallengeConstraints(false, Optional.empty()));
final String pushChallengeJson = """
{
"type": "rateLimitPushChallenge",
"challenge": "Hello I am a push challenge token"
}
""";
final Response response = EXTENSION.target("/v1/challenge")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.json(pushChallengeJson));
assertEquals(429, response.getStatus());
verifyNoInteractions(rateLimitChallengeManager);
}
@Test
void testValidationError() {
final String unrecognizedJson = """