Add opt-in timeouts to provisioning websocket

This commit is contained in:
ravi-signal
2024-12-18 18:45:53 -06:00
committed by GitHub
parent 6460327372
commit 68f27be7cd
10 changed files with 310 additions and 37 deletions

View File

@@ -0,0 +1,180 @@
package org.whispersystems.textsecuregcm;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME;
import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.ServletRegistration;
import java.io.IOException;
import java.net.URI;
import java.time.Duration;
import java.util.EnumSet;
import java.util.Objects;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.push.ProvisioningManager;
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
import org.whispersystems.textsecuregcm.websocket.ProvisioningConnectListener;
import org.whispersystems.websocket.WebsocketHeaders;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.messages.InvalidMessageException;
import org.whispersystems.websocket.messages.WebSocketMessage;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ExtendWith(DropwizardExtensionsSupport.class)
public class ProvisioningTimeoutIntegrationTest {
private static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION =
new DropwizardAppExtension<>(TestApplication.class);
private WebSocketClient client;
@BeforeEach
void setUp() throws Exception {
client = new WebSocketClient();
client.start();
final TestApplication testApplication = DROPWIZARD_APP_EXTENSION.getApplication();
reset(testApplication.scheduler);
}
@AfterEach
void tearDown() throws Exception {
client.stop();
}
public static class TestProvisioningListener extends TestWebsocketListener {
@Override
public void onWebSocketBinary(final byte[] payload, final int offset, final int length) {
try {
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length);
if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.REQUEST_MESSAGE
&& webSocketMessage.getRequestMessage().getPath().equals("/v1/address")) {
// ignore, this is the provisioning address the server sends on connect
return;
}
} catch (InvalidMessageException e) {
throw new RuntimeException(e);
}
super.onWebSocketBinary(payload, offset, length);
}
}
public static class TestApplication extends Application<Configuration> {
ScheduledExecutorService scheduler = mock(ScheduledExecutorService.class);
@Override
public void run(final Configuration configuration, final Environment environment) throws Exception {
final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration();
final WebSocketEnvironment<AuthenticatedDevice> webSocketEnvironment =
new WebSocketEnvironment<>(environment, webSocketConfiguration);
environment.servlets()
.addFilter("RemoteAddressFilter", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.setConnectListener(
new ProvisioningConnectListener(mock(ProvisioningManager.class), scheduler, Duration.ofSeconds(5)));
final WebSocketResourceProviderFactory<AuthenticatedDevice> webSocketServlet =
new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class,
webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
final ServletRegistration.Dynamic websocketServlet = environment.servlets()
.addServlet("WebSocket", webSocketServlet);
websocketServlet.addMapping("/websocket");
websocketServlet.setAsyncSupported(true);
}
}
@Test
public void websocketTimeoutWithHeader() throws IOException {
final TestProvisioningListener testWebsocketListener = new TestProvisioningListener();
final TestApplication testApplication = DROPWIZARD_APP_EXTENSION.getApplication();
when(testApplication.scheduler.schedule(any(Runnable.class), anyLong(), any()))
.thenReturn(mock(ScheduledFuture.class));
final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
upgradeRequest.setHeader(WebsocketHeaders.X_SIGNAL_WEBSOCKET_TIMEOUT_HEADER, "");
try (Session ignored = client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())),
upgradeRequest).join()) {
assertThat(testWebsocketListener.closeFuture()).isNotDone();
final ArgumentCaptor<Runnable> closeFunctionCaptor = ArgumentCaptor.forClass(Runnable.class);
verify(testApplication.scheduler).schedule(closeFunctionCaptor.capture(), anyLong(), any());
closeFunctionCaptor.getValue().run();
assertThat(testWebsocketListener.closeFuture())
.succeedsWithin(Duration.ofSeconds(1))
.isEqualTo(1000);
}
}
@Test
public void websocketTimeoutNoHeader() throws IOException {
final TestProvisioningListener testWebsocketListener = new TestProvisioningListener();
final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
try (Session ignored = client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())),
upgradeRequest).join()) {
assertThat(testWebsocketListener.closeFuture()).isNotDone();
final TestApplication testApplication = DROPWIZARD_APP_EXTENSION.getApplication();
verify(testApplication.scheduler, never()).schedule(any(Runnable.class), anyLong(), any());
assertThat(testWebsocketListener.closeFuture()).isNotDone();
}
}
@Test
public void websocketTimeoutCancelled() throws IOException {
final TestProvisioningListener testWebsocketListener = new TestProvisioningListener();
final TestApplication testApplication = DROPWIZARD_APP_EXTENSION.getApplication();
@SuppressWarnings("unchecked") final ScheduledFuture<Void> scheduled = mock(ScheduledFuture.class);
doReturn(scheduled).when(testApplication.scheduler).schedule(any(Runnable.class), anyLong(), any());
final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
upgradeRequest.setHeader(WebsocketHeaders.X_SIGNAL_WEBSOCKET_TIMEOUT_HEADER, "");
final Session session = client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())),
upgradeRequest).join();
// Close the websocket, make sure the timeout is cancelled.
session.close();
assertThat(testWebsocketListener.closeFuture()).succeedsWithin(Duration.ofSeconds(1));
verify(scheduled, times(1)).cancel(anyBoolean());
}
}

View File

@@ -19,7 +19,6 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.anyBoolean;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
@@ -132,7 +131,7 @@ import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.TestClock;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.websocket.Stories;
import org.whispersystems.websocket.WebsocketHeaders;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;
@@ -675,7 +674,7 @@ class MessageControllerTest {
resources.getJerseyTest().target("/v1/messages/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(Stories.X_SIGNAL_RECEIVE_STORIES, receiveStories ? "true" : "false")
.header(WebsocketHeaders.X_SIGNAL_RECEIVE_STORIES, receiveStories ? "true" : "false")
.header(HttpHeaders.USER_AGENT, userAgent)
.accept(MediaType.APPLICATION_JSON_TYPE)
.get(OutgoingMessageEntityList.class);

View File

@@ -24,8 +24,9 @@ public class TestWebsocketListener implements WebSocketListener {
private final AtomicLong requestId = new AtomicLong();
private final CompletableFuture<Session> started = new CompletableFuture<>();
private final CompletableFuture<Integer> closed = new CompletableFuture<>();
private final ConcurrentHashMap<Long, CompletableFuture<WebSocketResponseMessage>> responseFutures = new ConcurrentHashMap<>();
private final WebSocketMessageFactory messageFactory;
protected final WebSocketMessageFactory messageFactory;
public TestWebsocketListener() {
this.messageFactory = new ProtobufWebSocketMessageFactory();
@@ -38,6 +39,15 @@ public class TestWebsocketListener implements WebSocketListener {
}
@Override
public void onWebSocketClose(int statusCode, String reason) {
closed.complete(statusCode);
}
public CompletableFuture<Integer> closeFuture() {
return closed;
}
public CompletableFuture<WebSocketResponseMessage> doGet(final String requestPath) {
return sendRequest(requestPath, "GET", List.of("Accept: application/json"), Optional.empty());
}

View File

@@ -1,6 +1,24 @@
package org.whispersystems.textsecuregcm.websocket;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.protobuf.InvalidProtocolBufferException;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
@@ -9,23 +27,20 @@ import org.whispersystems.textsecuregcm.push.ProvisioningManager;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import java.util.Optional;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
class ProvisioningConnectListenerTest {
private ProvisioningManager provisioningManager;
private ProvisioningConnectListener provisioningConnectListener;
private ScheduledExecutorService scheduledExecutorService;
private static Duration TIMEOUT = Duration.ofSeconds(5);
@BeforeEach
void setUp() {
provisioningManager = mock(ProvisioningManager.class);
provisioningConnectListener = new ProvisioningConnectListener(provisioningManager);
scheduledExecutorService = mock(ScheduledExecutorService.class);
provisioningConnectListener =
new ProvisioningConnectListener(provisioningManager, scheduledExecutorService, TIMEOUT);
}
@Test
@@ -60,4 +75,49 @@ class ProvisioningConnectListenerTest {
assertEquals(addListenerProvisioningAddressCaptor.getValue(), removeListenerProvisioningAddressCaptor.getValue());
assertEquals(addListenerProvisioningAddressCaptor.getValue(), sentProvisioningAddress);
}
@Test
void schedulesTimeout() {
final WebSocketClient webSocketClient = mock(WebSocketClient.class);
final WebSocketSessionContext context = new WebSocketSessionContext(webSocketClient);
when(webSocketClient.supportsProvisioningSocketTimeouts()).thenReturn(true);
final ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
doReturn(scheduledFuture).when(scheduledExecutorService).schedule(any(Runnable.class), anyLong(), any());
final ArgumentCaptor<Runnable> scheduleCaptor = ArgumentCaptor.forClass(Runnable.class);
provisioningConnectListener.onWebSocketConnect(context);
verify(scheduledExecutorService).schedule(scheduleCaptor.capture(), eq(TIMEOUT.getSeconds()), eq(TimeUnit.SECONDS));
verify(webSocketClient, never()).close(anyInt(), any());
scheduleCaptor.getValue().run();
verify(webSocketClient, times(1)).close(eq(1000), anyString());
}
@Test
void cancelsTimeout() {
final WebSocketClient webSocketClient = mock(WebSocketClient.class);
final WebSocketSessionContext context = new WebSocketSessionContext(webSocketClient);
when(webSocketClient.supportsProvisioningSocketTimeouts()).thenReturn(true);
final ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
doReturn(scheduledFuture).when(scheduledExecutorService).schedule(any(Runnable.class), anyLong(), any());
provisioningConnectListener.onWebSocketConnect(context);
verify(scheduledExecutorService).schedule(any(Runnable.class), eq(TIMEOUT.getSeconds()), eq(TimeUnit.SECONDS));
context.notifyClosed(1000, "Test");
verify(scheduledFuture).cancel(false);
verify(webSocketClient, never()).close(anyInt(), any());
}
@Test
void skipsTimeoutIfUnsupported() {
final WebSocketClient webSocketClient = mock(WebSocketClient.class);
final WebSocketSessionContext context = new WebSocketSessionContext(webSocketClient);
provisioningConnectListener.onWebSocketConnect(context);
verify(scheduledExecutorService, never())
.schedule(any(Runnable.class), eq(TIMEOUT.getSeconds()), eq(TimeUnit.SECONDS));
}
}