Add ExternalRequestFilter

This commit is contained in:
Chris Eager
2024-01-29 16:48:56 -06:00
committed by Chris Eager
parent 63c8b275d1
commit 5b97bc04e0
10 changed files with 581 additions and 1 deletions

View File

@@ -0,0 +1,245 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.filters;
import static org.junit.jupiter.api.Assertions.assertEquals;
import com.google.protobuf.ByteString;
import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.testing.DropwizardTestSupport;
import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.Status;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import java.net.InetAddress;
import java.util.Collections;
import java.util.EnumSet;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import javax.servlet.DispatcherType;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.client.Client;
import javax.ws.rs.core.Response;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.signal.chat.rpc.EchoRequest;
import org.signal.chat.rpc.EchoServiceGrpc;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.util.InetAddressRange;
@ExtendWith(DropwizardExtensionsSupport.class)
class ExternalRequestFilterTest {
@Nested
class Allowed extends TestCase {
@Override
DropwizardTestSupport<TestConfiguration> getTestSupport() {
return new DropwizardTestSupport<>(TestApplication.class, getConfiguration());
}
@Override
int getExpectedHttpStatus() {
return 200;
}
@Override
Status getExpectedGrpcStatus() {
return Status.OK;
}
@Override
TestConfiguration getConfiguration() {
return new TestConfiguration() {
@Override
public Set<InetAddressRange> getPermittedRanges() {
return Set.of(new InetAddressRange("127.0.0.0/8"));
}
};
}
}
@Nested
class Blocked extends TestCase {
@Override
DropwizardTestSupport<TestConfiguration> getTestSupport() {
return new DropwizardTestSupport<>(TestApplication.class, getConfiguration());
}
@Override
int getExpectedHttpStatus() {
return 404;
}
@Override
Status getExpectedGrpcStatus() {
return Status.NOT_FOUND;
}
@Override
TestConfiguration getConfiguration() {
return new TestConfiguration() {
@Override
public Set<InetAddressRange> getPermittedRanges() {
return Set.of(new InetAddressRange("10.0.0.0/8"));
}
};
}
}
abstract static class TestCase {
abstract DropwizardTestSupport<TestConfiguration> getTestSupport();
abstract TestConfiguration getConfiguration();
abstract int getExpectedHttpStatus();
abstract Status getExpectedGrpcStatus();
private Server testServer;
private ManagedChannel channel;
@Nested
class Http {
private final DropwizardAppExtension<TestConfiguration> DROPWIZARD_APP_EXTENSION =
new DropwizardAppExtension<>(getTestSupport());
@Test
void testRestricted() {
Client client = DROPWIZARD_APP_EXTENSION.client();
try (Response response = client.target(
"http://localhost:%s/test/restricted".formatted(DROPWIZARD_APP_EXTENSION.getLocalPort()))
.request()
.get()) {
assertEquals(getExpectedHttpStatus(), response.getStatus());
}
}
@Test
void testOpen() {
Client client = DROPWIZARD_APP_EXTENSION.client();
try (Response response = client.target(
"http://localhost:%s/test/open".formatted(DROPWIZARD_APP_EXTENSION.getLocalPort()))
.request()
.get()) {
assertEquals(200, response.getStatus());
}
}
}
@Nested
class Grpc {
@BeforeEach
void setUp() throws Exception {
final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor();
mockRequestAttributesInterceptor.setRemoteAddress(InetAddress.getByName("127.0.0.1"));
testServer = InProcessServerBuilder.forName("ExternalRequestFilterTest")
.directExecutor()
.addService(new EchoServiceImpl())
.intercept(new ExternalRequestFilter(getConfiguration().getPermittedRanges(),
Set.of("org.signal.chat.rpc.EchoService/echo2")))
.intercept(mockRequestAttributesInterceptor)
.build()
.start();
channel = InProcessChannelBuilder.forName("ExternalRequestFilterTest")
.directExecutor()
.build();
}
@Test
void testBlocked() {
final EchoServiceGrpc.EchoServiceBlockingStub client = EchoServiceGrpc.newBlockingStub(channel);
final String text = "0123456789";
final EchoRequest req = EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8(text)).build();
final Status expectedGrpcStatus = getExpectedGrpcStatus();
if (Status.Code.OK == expectedGrpcStatus.getCode()) {
assertEquals(text, client.echo2(req).getPayload().toStringUtf8());
} else {
GrpcTestUtils.assertStatusException(expectedGrpcStatus, () -> client.echo2(req));
}
}
@Test
void testOpen() {
final EchoServiceGrpc.EchoServiceBlockingStub client = EchoServiceGrpc.newBlockingStub(channel);
final String text = "0123456789";
final EchoRequest req = EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8(text)).build();
assertEquals(text, client.echo(req).getPayload().toStringUtf8());
}
@AfterEach
void tearDown() throws Exception {
testServer.shutdownNow()
.awaitTermination(10, TimeUnit.SECONDS);
}
}
@Path("/test")
public static class Controller {
@GET
@Path("/restricted")
public Response restricted() {
return Response.ok().build();
}
@GET
@Path("/open")
public Response open() {
return Response.ok().build();
}
}
public static class TestApplication extends Application<TestConfiguration> {
@Override
public void run(final TestConfiguration configuration, final Environment environment) throws Exception {
environment.jersey().register(new Controller());
environment.servlets()
.addFilter("ExternalRequestFilter",
new ExternalRequestFilter(configuration.getPermittedRanges(),
Collections.emptySet()))
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), true, "/test/restricted");
}
}
public abstract static class TestConfiguration extends Configuration {
public abstract Set<InetAddressRange> getPermittedRanges();
}
}
}

View File

@@ -6,9 +6,9 @@
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.stub.StreamObserver;
import org.signal.chat.rpc.EchoServiceGrpc;
import org.signal.chat.rpc.EchoRequest;
import org.signal.chat.rpc.EchoResponse;
import org.signal.chat.rpc.EchoServiceGrpc;
public class EchoServiceImpl extends EchoServiceGrpc.EchoServiceImplBase {
@Override
@@ -16,4 +16,10 @@ public class EchoServiceImpl extends EchoServiceGrpc.EchoServiceImplBase {
responseObserver.onNext(EchoResponse.newBuilder().setPayload(req.getPayload()).build());
responseObserver.onCompleted();
}
@Override
public void echo2(EchoRequest req, StreamObserver<EchoResponse> responseObserver) {
responseObserver.onNext(EchoResponse.newBuilder().setPayload(req.getPayload()).build());
responseObserver.onCompleted();
}
}

View File

@@ -0,0 +1,78 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.Map;
import java.util.stream.Stream;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
class InetAddressRangeTest {
@ParameterizedTest
@ValueSource(strings = {"192.168.0.1", "192.168.0.0/33", "$%#*(@!&^$/24", "192.168.0.0/fish", "signal.org"})
void testBogusCidrBlock(final String cidrBlock) {
assertThrows(IllegalArgumentException.class, () -> new InetAddressRange(cidrBlock));
}
@ParameterizedTest
@MethodSource("argumentsForTestGeneratePrefixMask")
void testGeneratePrefixMask(final int addressLengthBytes, final int prefixLengthBits, final byte[] expectedMask) {
assertArrayEquals(expectedMask, InetAddressRange.generatePrefixMask(addressLengthBytes, prefixLengthBits));
}
private static Stream<Arguments> argumentsForTestGeneratePrefixMask() {
return Stream.of(
Arguments.of(4, 32, new byte[]{(byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff}),
Arguments.of(4, 24, new byte[]{(byte) 0xff, (byte) 0xff, (byte) 0xff, 0x00}),
Arguments.of(4, 22, new byte[]{(byte) 0xff, (byte) 0xff, (byte) 0xfc, 0x00}),
Arguments.of(4, 0, new byte[]{0x00, 0x00, 0x00, 0x00})
);
}
@ParameterizedTest
@MethodSource("argumentsForTestContains")
void testContains(final String cidrBlock, final String address, final boolean expectContains) {
assertEquals(expectContains, new InetAddressRange(cidrBlock).contains(address));
}
private static Stream<Arguments> argumentsForTestContains() {
return Stream.of(
Arguments.of("192.168.0.0/24", "192.168.0.1", true),
Arguments.of("192.168.0.0/24", "192.168.1.0", false),
Arguments.of("192.168.0.1/32", "192.168.0.1", true),
Arguments.of("192.168.0.1/32", "192.168.0.0", false),
Arguments.of("2001:db8::/48", "2001:db8:0:0:0:0:0:0", true),
Arguments.of("2001:db8::/48", "2001:db8:0:ffff:ffff:ffff:ffff:ffff", true),
Arguments.of("2001:db8::/48", "2001:db6:0:ffff:ffff:ffff:ffff:ffff", false)
);
}
@Test
void testContainsMismatchedAddressType() {
assertFalse(new InetAddressRange("192.168.0.0/24").contains("2001:db8:0:0:0:0:0:0"));
assertFalse(new InetAddressRange("2001:db8::/48").contains("192.168.0.1"));
}
@Test
void testDeserialize() throws JsonProcessingException {
final TypeReference<Map<String, InetAddressRange>> typeReference = new TypeReference<>() {};
assertEquals(Map.of("range", new InetAddressRange("192.168.0.0/24")),
new ObjectMapper().readValue("{\"range\":\"192.168.0.0/24\"}", typeReference));
}
}