Don't cache authenticated accounts in memory

This commit is contained in:
Jon Chambers
2025-06-23 08:40:05 -05:00
committed by GitHub
parent 9dfe51eac4
commit c952baa672
86 changed files with 961 additions and 2264 deletions

View File

@@ -21,6 +21,7 @@ import jakarta.ws.rs.core.MediaType;
import java.io.IOException;
import java.net.URI;
import java.util.EnumSet;
import java.util.Optional;
import org.apache.commons.lang3.RandomStringUtils;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
@@ -34,9 +35,7 @@ import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@@ -78,8 +77,7 @@ public class WebsocketResourceProviderIntegrationTest {
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(testController);
webSocketEnvironment.jersey().register(new RemoteAddressFilter());
webSocketEnvironment.setAuthenticator(upgradeRequest ->
ReusableAuth.authenticated(mock(AuthenticatedDevice.class), PrincipalSupplier.forImmutablePrincipal()));
webSocketEnvironment.setAuthenticator(upgradeRequest -> Optional.of(mock(AuthenticatedDevice.class)));
webSocketEnvironment.jersey().property(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE);
webSocketEnvironment.setConnectListener(webSocketSessionContext -> {

View File

@@ -1,279 +0,0 @@
package org.whispersystems.textsecuregcm;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME;
import io.dropwizard.auth.Auth;
import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.ServletRegistration;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.PathParam;
import java.io.IOException;
import java.net.URI;
import java.util.EnumSet;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.IntStream;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.glassfish.jersey.server.ManagedAsync;
import org.glassfish.jersey.server.ServerProperties;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.storage.RefreshingAccountNotFoundException;
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.ReadOnly;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ExtendWith(DropwizardExtensionsSupport.class)
public class WebsocketReuseAuthIntegrationTest {
private static final AuthenticatedDevice ACCOUNT = mock(AuthenticatedDevice.class);
@SuppressWarnings("unchecked")
private static final PrincipalSupplier<AuthenticatedDevice> PRINCIPAL_SUPPLIER = mock(PrincipalSupplier.class);
private static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION =
new DropwizardAppExtension<>(TestApplication.class);
private WebSocketClient client;
@BeforeEach
void setUp() throws Exception {
reset(PRINCIPAL_SUPPLIER);
reset(ACCOUNT);
when(ACCOUNT.getName()).thenReturn("original");
client = new WebSocketClient();
client.start();
}
@AfterEach
void tearDown() throws Exception {
client.stop();
}
public static class TestApplication extends Application<Configuration> {
@Override
public void run(final Configuration configuration, final Environment environment) throws Exception {
final TestController testController = new TestController();
final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration();
final WebSocketEnvironment<AuthenticatedDevice> webSocketEnvironment =
new WebSocketEnvironment<>(environment, webSocketConfiguration);
environment.jersey().register(testController);
environment.servlets()
.addFilter("RemoteAddressFilter", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(testController);
webSocketEnvironment.jersey().register(new RemoteAddressFilter());
webSocketEnvironment.setAuthenticator(upgradeRequest -> ReusableAuth.authenticated(ACCOUNT, PRINCIPAL_SUPPLIER));
webSocketEnvironment.jersey().property(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE);
webSocketEnvironment.setConnectListener(webSocketSessionContext -> {
});
final WebSocketResourceProviderFactory<AuthenticatedDevice> webSocketServlet =
new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class,
webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
final ServletRegistration.Dynamic websocketServlet =
environment.servlets().addServlet("WebSocket", webSocketServlet);
websocketServlet.addMapping("/websocket");
websocketServlet.setAsyncSupported(true);
}
}
private WebSocketResponseMessage make1WebsocketRequest(final String requestPath) throws IOException {
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
return testWebsocketListener.doGet(requestPath).join();
}
@ParameterizedTest
@ValueSource(strings = {"/test/read-auth", "/test/optional-read-auth"})
public void readAuth(final String path) throws IOException {
final WebSocketResponseMessage response = make1WebsocketRequest(path);
assertThat(response.getStatus()).isEqualTo(200);
verifyNoMoreInteractions(PRINCIPAL_SUPPLIER);
}
@ParameterizedTest
@ValueSource(strings = {"/test/write-auth", "/test/optional-write-auth"})
public void writeAuth(final String path) throws IOException {
final AuthenticatedDevice copiedAccount = mock(AuthenticatedDevice.class);
when(copiedAccount.getName()).thenReturn("copy");
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(copiedAccount);
final WebSocketResponseMessage response = make1WebsocketRequest(path);
assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.getBody().map(String::new)).get().isEqualTo("copy");
verify(PRINCIPAL_SUPPLIER, times(1)).deepCopy(any());
verifyNoMoreInteractions(PRINCIPAL_SUPPLIER);
}
@Test
public void readAfterWrite() throws IOException {
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(ACCOUNT);
final AuthenticatedDevice account2 = mock(AuthenticatedDevice.class);
when(account2.getName()).thenReturn("refresh");
when(PRINCIPAL_SUPPLIER.refresh(any())).thenReturn(account2);
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
final WebSocketResponseMessage readResponse = testWebsocketListener.doGet("/test/read-auth").join();
assertThat(readResponse.getBody().map(String::new)).get().isEqualTo("original");
final WebSocketResponseMessage writeResponse = testWebsocketListener.doGet("/test/write-auth").join();
assertThat(writeResponse.getBody().map(String::new)).get().isEqualTo("original");
final WebSocketResponseMessage readResponse2 = testWebsocketListener.doGet("/test/read-auth").join();
assertThat(readResponse2.getBody().map(String::new)).get().isEqualTo("refresh");
}
@Test
public void readAfterWriteRefreshFails() throws IOException {
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(ACCOUNT);
when(PRINCIPAL_SUPPLIER.refresh(any())).thenThrow(RefreshingAccountNotFoundException.class);
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
final WebSocketResponseMessage writeResponse = testWebsocketListener.doGet("/test/write-auth").join();
assertThat(writeResponse.getBody().map(String::new)).get().isEqualTo("original");
final WebSocketResponseMessage readResponse2 = testWebsocketListener.doGet("/test/read-auth").join();
assertThat(readResponse2.getStatus()).isEqualTo(500);
}
@Test
public void readConcurrentWithWrite() throws IOException, ExecutionException, InterruptedException, TimeoutException {
final AuthenticatedDevice deepCopy = mock(AuthenticatedDevice.class);
when(deepCopy.getName()).thenReturn("deepCopy");
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(deepCopy);
final AuthenticatedDevice refresh = mock(AuthenticatedDevice.class);
when(refresh.getName()).thenReturn("refresh");
when(PRINCIPAL_SUPPLIER.refresh(any())).thenReturn(refresh);
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
// start a write request that takes a while to finish
final CompletableFuture<WebSocketResponseMessage> writeResponse =
testWebsocketListener.doGet("/test/start-delayed-write/foo");
// send a bunch of reads, they should reflect the original auth
final List<CompletableFuture<WebSocketResponseMessage>> futures = IntStream.range(0, 10)
.boxed().map(i -> testWebsocketListener.doGet("/test/read-auth"))
.toList();
CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)).join();
for (CompletableFuture<WebSocketResponseMessage> future : futures) {
assertThat(future.join().getBody().map(String::new)).get().isEqualTo("original");
}
assertThat(writeResponse.isDone()).isFalse();
// finish the delayed write request
testWebsocketListener.doGet("/test/finish-delayed-write/foo").get(1, TimeUnit.SECONDS);
assertThat(writeResponse.join().getBody().map(String::new)).get().isEqualTo("deepCopy");
// subsequent reads should have the refreshed auth
final WebSocketResponseMessage readResponse = testWebsocketListener.doGet("/test/read-auth").join();
assertThat(readResponse.getBody().map(String::new)).get().isEqualTo("refresh");
}
@Path("/test")
public static class TestController {
private final ConcurrentHashMap<String, CountDownLatch> delayedWriteLatches = new ConcurrentHashMap<>();
@GET
@Path("/read-auth")
@ManagedAsync
public String readAuth(@ReadOnly @Auth final AuthenticatedDevice account) {
return account.getName();
}
@GET
@Path("/optional-read-auth")
@ManagedAsync
public String optionalReadAuth(@ReadOnly @Auth final Optional<AuthenticatedDevice> account) {
return account.map(AuthenticatedDevice::getName).orElse("empty");
}
@GET
@Path("/write-auth")
@ManagedAsync
public String writeAuth(@Auth final AuthenticatedDevice account) {
return account.getName();
}
@GET
@Path("/optional-write-auth")
@ManagedAsync
public String optionalWriteAuth(@Auth final Optional<AuthenticatedDevice> account) {
return account.map(AuthenticatedDevice::getName).orElse("empty");
}
@GET
@Path("/start-delayed-write/{id}")
@ManagedAsync
public String startDelayedWrite(@Auth final AuthenticatedDevice account, @PathParam("id") String id)
throws InterruptedException {
delayedWriteLatches.computeIfAbsent(id, i -> new CountDownLatch(1)).await();
return account.getName();
}
@GET
@Path("/finish-delayed-write/{id}")
@ManagedAsync
public String finishDelayedWrite(@PathParam("id") String id) {
delayedWriteLatches.computeIfAbsent(id, i -> new CountDownLatch(1)).countDown();
return "ok";
}
}
}

View File

@@ -18,7 +18,6 @@ import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
class CertificateGeneratorTest {
@@ -37,7 +36,7 @@ class CertificateGeneratorTest {
@Test
void testCreateFor() throws IOException, InvalidKeyException, org.signal.libsignal.protocol.InvalidKeyException {
final Account account = mock(Account.class);
final Device device = mock(Device.class);
final byte deviceId = 4;
final CertificateGenerator certificateGenerator = new CertificateGenerator(
Base64.getDecoder().decode(SIGNING_CERTIFICATE),
Curve.decodePrivatePoint(Base64.getDecoder().decode(SIGNING_KEY)), 1);
@@ -45,9 +44,8 @@ class CertificateGeneratorTest {
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(IDENTITY_KEY);
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(account.getNumber()).thenReturn("+18005551234");
when(device.getId()).thenReturn((byte) 4);
assertTrue(certificateGenerator.createFor(account, device, true).length > 0);
assertTrue(certificateGenerator.createFor(account, device, false).length > 0);
assertTrue(certificateGenerator.createFor(account, deviceId, true).length > 0);
assertTrue(certificateGenerator.createFor(account, deviceId, false).length > 0);
}
}

View File

@@ -31,8 +31,6 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.util.TestClock;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.auth.PrincipalSupplier;
class IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilterTest {
@@ -59,9 +57,9 @@ class IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilterTest {
final boolean expectPqKeyCheck,
@Nullable final String expectedAlertHeader) {
final ReusableAuth<AuthenticatedDevice> reusableAuth = authenticatedDevice != null
? ReusableAuth.authenticated(authenticatedDevice, PrincipalSupplier.forImmutablePrincipal())
: ReusableAuth.anonymous();
final Optional<AuthenticatedDevice> reusableAuth = authenticatedDevice != null
? Optional.of(authenticatedDevice)
: Optional.empty();
final JettyServerUpgradeResponse response = mock(JettyServerUpgradeResponse.class);
@@ -88,20 +86,24 @@ class IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilterTest {
private static List<Arguments> handleAuthentication() {
final Device activePrimaryDevice = mock(Device.class);
when(activePrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(activePrimaryDevice.isPrimary()).thenReturn(true);
when(activePrimaryDevice.getLastSeen()).thenReturn(CLOCK.millis());
final Device minIdlePrimaryDevice = mock(Device.class);
when(minIdlePrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(minIdlePrimaryDevice.isPrimary()).thenReturn(true);
when(minIdlePrimaryDevice.getLastSeen())
.thenReturn(CLOCK.instant().minus(MIN_IDLE_DURATION).minusSeconds(1).toEpochMilli());
final Device longIdlePrimaryDevice = mock(Device.class);
when(longIdlePrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(longIdlePrimaryDevice.isPrimary()).thenReturn(true);
when(longIdlePrimaryDevice.getLastSeen())
.thenReturn(CLOCK.instant().minus(IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter.PQ_KEY_CHECK_THRESHOLD).minusSeconds(1).toEpochMilli());
final Device linkedDevice = mock(Device.class);
when(linkedDevice.getId()).thenReturn((byte) (Device.PRIMARY_ID + 1));
when(linkedDevice.isPrimary()).thenReturn(false);
final Account accountWithActivePrimaryDevice = mock(Account.class);

View File

@@ -1,328 +0,0 @@
/*
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import io.dropwizard.auth.Auth;
import io.dropwizard.auth.AuthDynamicFeature;
import io.dropwizard.auth.AuthValueFactoryProvider;
import io.dropwizard.auth.basic.BasicCredentialAuthFilter;
import io.dropwizard.jersey.DropwizardResourceConfig;
import io.dropwizard.jersey.jackson.JacksonMessageBodyProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import jakarta.ws.rs.DELETE;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.PUT;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.PathParam;
import jakarta.ws.rs.client.Entity;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.Principal;
import java.time.Duration;
import java.util.Arrays;
import java.util.Base64;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.websocket.api.WriteCallback;
import org.glassfish.jersey.server.ApplicationHandler;
import org.glassfish.jersey.server.ResourceConfig;
import org.glassfish.jersey.server.monitoring.ApplicationEventListener;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketResourceProvider;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
import org.whispersystems.websocket.logging.WebsocketRequestLog;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
import org.whispersystems.websocket.messages.protobuf.SubProtocol;
import org.whispersystems.websocket.session.WebSocketSessionContextValueFactoryProvider;
@ExtendWith(DropwizardExtensionsSupport.class)
class LinkedDeviceRefreshRequirementProviderTest {
private final ApplicationEventListener applicationEventListener = mock(ApplicationEventListener.class);
private final Account account = new Account();
private final Device authenticatedDevice = DevicesHelper.createDevice(Device.PRIMARY_ID);
private final Supplier<Optional<TestPrincipal>> principalSupplier = () -> Optional.of(
new TestPrincipal("test", account, authenticatedDevice));
private final ResourceExtension resources = ResourceExtension.builder()
.addProvider(new AuthDynamicFeature(new BasicCredentialAuthFilter.Builder<TestPrincipal>()
.setAuthenticator(c -> principalSupplier.get()).buildAuthFilter()))
.addProvider(new AuthValueFactoryProvider.Binder<>(TestPrincipal.class))
.addProvider(applicationEventListener)
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new TestResource())
.build();
private AccountsManager accountsManager;
private DisconnectionRequestManager disconnectionRequestManager;
private LinkedDeviceRefreshRequirementProvider provider;
@BeforeEach
void setup() {
accountsManager = mock(AccountsManager.class);
disconnectionRequestManager = mock(DisconnectionRequestManager.class);
provider = new LinkedDeviceRefreshRequirementProvider(accountsManager);
final WebsocketRefreshRequestEventListener listener =
new WebsocketRefreshRequestEventListener(disconnectionRequestManager, provider);
when(applicationEventListener.onRequest(any())).thenReturn(listener);
final UUID uuid = UUID.randomUUID();
account.setUuid(uuid);
account.addDevice(authenticatedDevice);
IntStream.range(2, 4)
.forEach(deviceId -> account.addDevice(DevicesHelper.createDevice((byte) deviceId)));
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account));
}
@Test
void testDeviceAdded() {
final int initialDeviceCount = account.getDevices().size();
final List<String> addedDeviceNames = List.of(
Base64.getEncoder().encodeToString("newDevice1".getBytes(StandardCharsets.UTF_8)),
Base64.getEncoder().encodeToString("newDevice2".getBytes(StandardCharsets.UTF_8)));
final Response response = resources.getJerseyTest()
.target("/v1/test/account/devices")
.request()
.header("Authorization",
"Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
.put(Entity.entity(addedDeviceNames, MediaType.APPLICATION_JSON_PATCH_JSON));
assertEquals(200, response.getStatus());
assertEquals(initialDeviceCount + addedDeviceNames.size(), account.getDevices().size());
verify(disconnectionRequestManager).requestDisconnection(account.getUuid(), List.of((byte) 1));
verify(disconnectionRequestManager).requestDisconnection(account.getUuid(), List.of((byte) 2));
verify(disconnectionRequestManager).requestDisconnection(account.getUuid(), List.of((byte) 3));
}
@ParameterizedTest
@ValueSource(ints = {1, 2})
void testDeviceRemoved(final int removedDeviceCount) {
final List<Byte> initialDeviceIds = account.getDevices().stream().map(Device::getId).toList();
final List<Byte> deletedDeviceIds = account.getDevices().stream()
.map(Device::getId)
.filter(deviceId -> deviceId != Device.PRIMARY_ID)
.limit(removedDeviceCount)
.toList();
assert deletedDeviceIds.size() == removedDeviceCount;
final String deletedDeviceIdsParam = deletedDeviceIds.stream().map(String::valueOf)
.collect(Collectors.joining(","));
final Response response = resources.getJerseyTest()
.target("/v1/test/account/devices/" + deletedDeviceIdsParam)
.request()
.header("Authorization",
"Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
.delete();
assertEquals(200, response.getStatus());
initialDeviceIds.forEach(deviceId ->
verify(disconnectionRequestManager).requestDisconnection(account.getUuid(), List.of(deviceId)));
verifyNoMoreInteractions(disconnectionRequestManager);
}
@Test
void testOnEvent() {
Response response = resources.getJerseyTest()
.target("/v1/test/hello")
.request()
// no authorization required
.get();
assertEquals(200, response.getStatus());
response = resources.getJerseyTest()
.target("/v1/test/authorized")
.request()
.header("Authorization",
"Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
.get();
assertEquals(200, response.getStatus());
verify(accountsManager, never()).getByAccountIdentifier(any(UUID.class));
}
@Nested
class WebSocket {
private WebSocketResourceProvider<TestPrincipal> provider;
private RemoteEndpoint remoteEndpoint;
@BeforeEach
void setup() {
ResourceConfig resourceConfig = new DropwizardResourceConfig();
resourceConfig.register(applicationEventListener);
resourceConfig.register(new TestResource());
resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder());
resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class));
resourceConfig.register(new JacksonMessageBodyProvider(SystemMapper.jsonMapper()));
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
provider = new WebSocketResourceProvider<>("127.0.0.1", RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME,
applicationHandler, requestLog, TestPrincipal.reusableAuth("test", account, authenticatedDevice),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
remoteEndpoint = mock(RemoteEndpoint.class);
Session session = mock(Session.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getRemote()).thenReturn(remoteEndpoint);
when(session.getUpgradeRequest()).thenReturn(request);
provider.onWebSocketConnect(session);
}
@Test
void testOnEvent() throws Exception {
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello",
new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
final SubProtocol.WebSocketResponseMessage response = verifyAndGetResponse(remoteEndpoint);
assertEquals(200, response.getStatus());
}
private SubProtocol.WebSocketResponseMessage verifyAndGetResponse(final RemoteEndpoint remoteEndpoint)
throws IOException {
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
return SubProtocol.WebSocketMessage.parseFrom(responseBytesCaptor.getValue().array()).getResponse();
}
}
public static class TestPrincipal implements Principal, AccountAndAuthenticatedDeviceHolder {
private final String name;
private final Account account;
private final Device device;
private TestPrincipal(final String name, final Account account, final Device device) {
this.name = name;
this.account = account;
this.device = device;
}
@Override
public String getName() {
return name;
}
@Override
public Account getAccount() {
return account;
}
@Override
public Device getAuthenticatedDevice() {
return device;
}
public static ReusableAuth<TestPrincipal> reusableAuth(final String name, final Account account, final Device device) {
return ReusableAuth.authenticated(new TestPrincipal(name, account, device), PrincipalSupplier.forImmutablePrincipal());
}
}
@Path("/v1/test")
public static class TestResource {
@GET
@Path("/hello")
public String testGetHello() {
return "Hello!";
}
@GET
@Path("/authorized")
public String testAuth(@Auth TestPrincipal principal) {
return "Youre in!";
}
@PUT
@Path("/account/devices")
@ChangesLinkedDevices
public String addDevices(@Auth TestPrincipal auth, List<byte[]> deviceNames) {
deviceNames.forEach(name -> {
final Device device = DevicesHelper.createDevice(auth.getAccount().getNextDeviceId());
auth.getAccount().addDevice(device);
device.setName(name);
});
return "Added devices " + deviceNames;
}
@DELETE
@Path("/account/devices/{deviceIds}")
@ChangesLinkedDevices
public String removeDevices(@Auth TestPrincipal auth, @PathParam("deviceIds") String deviceIds) {
Arrays.stream(deviceIds.split(","))
.map(Byte::valueOf)
.forEach(auth.getAccount()::removeDevice);
return "Removed device(s) " + deviceIds;
}
}
}

View File

@@ -1,294 +0,0 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME;
import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.Auth;
import io.dropwizard.auth.AuthDynamicFeature;
import io.dropwizard.auth.basic.BasicCredentialAuthFilter;
import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.ServletRegistration;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.client.Invocation;
import java.io.IOException;
import java.net.URI;
import java.util.Collections;
import java.util.EnumSet;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.glassfish.jersey.server.ManagedAsync;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.ReadOnly;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ExtendWith(DropwizardExtensionsSupport.class)
class PhoneNumberChangeRefreshRequirementProviderTest {
private static final String NUMBER = "+18005551234";
private static final String CHANGED_NUMBER = "+18005554321";
private static final String TEST_CRED_HEADER = HeaderUtils.basicAuthHeader("test", "password");
private static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION = new DropwizardAppExtension<>(
TestApplication.class);
private static final AccountAuthenticator AUTHENTICATOR = mock(AccountAuthenticator.class);
private static final AccountsManager ACCOUNTS_MANAGER = mock(AccountsManager.class);
private static final DisconnectionRequestManager DISCONNECTION_REQUEST_MANAGER =
mock(DisconnectionRequestManager.class);
private WebSocketClient client;
private final Account account1 = new Account();
private final Account account2 = new Account();
private final Device authenticatedDevice = DevicesHelper.createDevice(Device.PRIMARY_ID);
@BeforeEach
void setUp() throws Exception {
reset(AUTHENTICATOR, ACCOUNTS_MANAGER, DISCONNECTION_REQUEST_MANAGER);
client = new WebSocketClient();
client.start();
final UUID uuid = UUID.randomUUID();
account1.setUuid(uuid);
account1.addDevice(authenticatedDevice);
account1.setNumber(NUMBER, UUID.randomUUID());
account2.setUuid(uuid);
account2.addDevice(authenticatedDevice);
account2.setNumber(CHANGED_NUMBER, UUID.randomUUID());
}
@AfterEach
void tearDown() throws Exception {
client.stop();
}
public static class TestApplication extends Application<Configuration> {
@Override
public void run(final Configuration configuration, final Environment environment) throws Exception {
final TestController testController = new TestController();
final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration();
final WebSocketEnvironment<AuthenticatedDevice> webSocketEnvironment =
new WebSocketEnvironment<>(environment, webSocketConfiguration);
environment.jersey().register(testController);
webSocketEnvironment.jersey().register(testController);
environment.servlets()
.addFilter("RemoteAddressFilter", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(new RemoteAddressFilter());
webSocketEnvironment.jersey()
.register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, DISCONNECTION_REQUEST_MANAGER));
environment.jersey()
.register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, DISCONNECTION_REQUEST_MANAGER));
webSocketEnvironment.setConnectListener(webSocketSessionContext -> {
});
environment.jersey().register(new AuthDynamicFeature(new BasicCredentialAuthFilter.Builder<AuthenticatedDevice>()
.setAuthenticator(AUTHENTICATOR)
.buildAuthFilter()));
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(AUTHENTICATOR, mock(PrincipalSupplier.class)));
final WebSocketResourceProviderFactory<AuthenticatedDevice> webSocketServlet =
new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class,
webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
final ServletRegistration.Dynamic websocketServlet =
environment.servlets().addServlet("WebSocket", webSocketServlet);
websocketServlet.addMapping("/websocket");
websocketServlet.setAsyncSupported(true);
}
}
enum Protocol { HTTP, WEBSOCKET }
private void makeAnonymousRequest(final Protocol protocol, final String requestPath) throws IOException {
makeRequest(protocol, requestPath, true);
}
/*
* Make an authenticated request that will return account1 as the principal
*/
private void makeAuthenticatedRequest(
final Protocol protocol,
final String requestPath) throws IOException {
when(AUTHENTICATOR.authenticate(any())).thenReturn(Optional.of(new AuthenticatedDevice(account1, authenticatedDevice)));
makeRequest(protocol,requestPath, false);
}
private void makeRequest(final Protocol protocol, final String requestPath, final boolean anonymous) throws IOException {
switch (protocol) {
case WEBSOCKET -> {
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
if (!anonymous) {
upgradeRequest.setHeader(HttpHeaders.AUTHORIZATION, TEST_CRED_HEADER);
}
client.connect(
testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())),
upgradeRequest);
testWebsocketListener.sendRequest(requestPath, "GET", Collections.emptyList(), Optional.empty()).join();
}
case HTTP -> {
final Invocation.Builder request = DROPWIZARD_APP_EXTENSION.client()
.target("http://127.0.0.1:%s%s".formatted(DROPWIZARD_APP_EXTENSION.getLocalPort(), requestPath))
.request();
if (!anonymous) {
request.header(HttpHeaders.AUTHORIZATION, TEST_CRED_HEADER);
}
request.get();
}
}
}
@ParameterizedTest
@EnumSource(Protocol.class)
void handleRequestNoChange(final Protocol protocol) throws IOException {
when(ACCOUNTS_MANAGER.getByAccountIdentifier(any())).thenReturn(Optional.of(account1));
makeAuthenticatedRequest(protocol, "/test/annotated");
// Event listeners can fire after responses are sent
verify(ACCOUNTS_MANAGER, timeout(5000).times(1)).getByAccountIdentifier(eq(account1.getUuid()));
verifyNoMoreInteractions(ACCOUNTS_MANAGER);
}
@ParameterizedTest
@EnumSource(Protocol.class)
void handleRequestChange(final Protocol protocol) throws IOException {
when(ACCOUNTS_MANAGER.getByAccountIdentifier(any())).thenReturn(Optional.of(account2));
when(AUTHENTICATOR.authenticate(any())).thenReturn(Optional.of(new AuthenticatedDevice(account1, authenticatedDevice)));
makeAuthenticatedRequest(protocol, "/test/annotated");
// Make sure we disconnect the account if the account has changed numbers. Event listeners can fire after responses
// are sent, so use a timeout.
verify(DISCONNECTION_REQUEST_MANAGER, timeout(5000))
.requestDisconnection(account1.getUuid(), List.of(authenticatedDevice.getId()));
verifyNoMoreInteractions(DISCONNECTION_REQUEST_MANAGER);
}
@Test
void handleRequestChangeAsyncEndpoint() throws IOException {
when(ACCOUNTS_MANAGER.getByAccountIdentifier(any())).thenReturn(Optional.of(account2));
when(AUTHENTICATOR.authenticate(any())).thenReturn(Optional.of(new AuthenticatedDevice(account1, authenticatedDevice)));
// Event listeners with asynchronous HTTP endpoints don't currently correctly maintain state between request and
// response
makeAuthenticatedRequest(Protocol.WEBSOCKET, "/test/async-annotated");
// Make sure we disconnect the account if the account has changed numbers. Event listeners can fire after responses
// are sent, so use a timeout.
verify(DISCONNECTION_REQUEST_MANAGER, timeout(5000))
.requestDisconnection(account1.getUuid(), List.of(authenticatedDevice.getId()));
verifyNoMoreInteractions(DISCONNECTION_REQUEST_MANAGER);
}
@ParameterizedTest
@EnumSource(Protocol.class)
void handleRequestNotAnnotated(final Protocol protocol) throws IOException, InterruptedException {
makeAuthenticatedRequest(protocol,"/test/not-annotated");
// Give a tick for event listeners to run. Racy, but should occasionally catch an errant running listener if one is
// introduced.
Thread.sleep(100);
// Shouldn't even read the account if the method has not been annotated
verifyNoMoreInteractions(ACCOUNTS_MANAGER);
}
@ParameterizedTest
@EnumSource(Protocol.class)
void handleRequestNotAuthenticated(final Protocol protocol) throws IOException, InterruptedException {
makeAnonymousRequest(protocol, "/test/not-authenticated");
// Give a tick for event listeners to run. Racy, but should occasionally catch an errant running listener if one is
// introduced.
Thread.sleep(100);
// Shouldn't even read the account if the method has not been annotated
verifyNoMoreInteractions(ACCOUNTS_MANAGER);
}
@Path("/test")
public static class TestController {
@GET
@Path("/annotated")
@ChangesPhoneNumber
public String annotated(@ReadOnly @Auth final AuthenticatedDevice account) {
return "ok";
}
@GET
@Path("/async-annotated")
@ChangesPhoneNumber
@ManagedAsync
public String asyncAnnotated(@ReadOnly @Auth final AuthenticatedDevice account) {
return "ok";
}
@GET
@Path("/not-authenticated")
@ChangesPhoneNumber
public String notAuthenticated() {
return "ok";
}
@GET
@Path("/not-annotated")
public String notAnnotated(@ReadOnly @Auth final AuthenticatedDevice account) {
return "ok";
}
}
}

View File

@@ -211,6 +211,11 @@ class AccountControllerTest {
when(accountsManager.getByE164(eq(SENDER_HAS_STORAGE))).thenReturn(Optional.of(senderHasStorage));
when(accountsManager.getByE164(eq(SENDER_TRANSFER))).thenReturn(Optional.of(senderTransfer));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_TWO)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_TWO));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_3)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_3));
when(accountsManager.getByAccountIdentifier(AuthHelper.UNDISCOVERABLE_UUID)).thenReturn(Optional.of(AuthHelper.UNDISCOVERABLE_ACCOUNT));
doAnswer(invocation -> {
final byte[] proof = invocation.getArgument(0);
final byte[] hash = invocation.getArgument(1);

View File

@@ -145,6 +145,8 @@ class AccountControllerV2Test {
void setUp() throws Exception {
when(rateLimiters.getRegistrationLimiter()).thenReturn(registrationLimiter);
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(changeNumberManager.changeNumber(any(), any(), any(), any(), any(), any(), any(), any())).thenAnswer(
(Answer<Account>) invocation -> {
final Account account = invocation.getArgument(0);
@@ -607,6 +609,8 @@ class AccountControllerV2Test {
@BeforeEach
void setUp() throws Exception {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(changeNumberManager.updatePniKeys(any(), any(), any(), any(), any(), any(), any())).thenAnswer(
(Answer<Account>) invocation -> {
final Account account = invocation.getArgument(0);
@@ -768,7 +772,9 @@ class AccountControllerV2Test {
@BeforeEach
void setup() {
AccountsHelper.setupMockUpdate(accountsManager);
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
}
@Test
void testSetPhoneNumberDiscoverability() {
Response response = resources.getJerseyTest()
@@ -805,6 +811,7 @@ class AccountControllerV2Test {
@MethodSource
void testGetAccountDataReport(final Account account, final String expectedTextAfterHeader) throws Exception {
when(AuthHelper.ACCOUNTS_MANAGER.getByAccountIdentifier(account.getUuid())).thenReturn(Optional.of(account));
when(accountsManager.getByAccountIdentifier(account.getUuid())).thenReturn(Optional.of(account));
final Response response = resources.getJerseyTest()
.target("/v2/accounts/data_report")

View File

@@ -73,6 +73,7 @@ import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.GrpcStatusRuntimeExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.metrics.BackupMetrics;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.EnumMapUtil;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@@ -82,6 +83,7 @@ import reactor.core.publisher.Flux;
@ExtendWith(DropwizardExtensionsSupport.class)
public class ArchiveControllerTest {
private static final AccountsManager accountsManager = mock(AccountsManager.class);
private static final BackupAuthManager backupAuthManager = mock(BackupAuthManager.class);
private static final BackupManager backupManager = mock(BackupManager.class);
private final BackupAuthTestUtil backupAuthTestUtil = new BackupAuthTestUtil(Clock.systemUTC());
@@ -95,7 +97,7 @@ public class ArchiveControllerTest {
.addProvider(new RateLimitExceededExceptionMapper())
.setMapper(SystemMapper.jsonMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new ArchiveController(backupAuthManager, backupManager, new BackupMetrics()))
.addResource(new ArchiveController(accountsManager, backupAuthManager, backupManager, new BackupMetrics()))
.build();
private final UUID aci = UUID.randomUUID();
@@ -106,6 +108,9 @@ public class ArchiveControllerTest {
public void setUp() {
reset(backupAuthManager);
reset(backupManager);
when(accountsManager.getByAccountIdentifierAsync(AuthHelper.VALID_UUID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(AuthHelper.VALID_ACCOUNT)));
}
@ParameterizedTest

View File

@@ -9,6 +9,8 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import io.dropwizard.auth.AuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
@@ -21,9 +23,11 @@ import java.time.Instant;
import java.time.ZoneId;
import java.time.temporal.ChronoUnit;
import java.util.Base64;
import java.util.Optional;
import java.util.stream.Stream;
import org.apache.commons.lang3.StringUtils;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
@@ -44,6 +48,7 @@ import org.whispersystems.textsecuregcm.entities.DeliveryCertificate;
import org.whispersystems.textsecuregcm.entities.GroupCredentials;
import org.whispersystems.textsecuregcm.entities.MessageProtos.SenderCertificate;
import org.whispersystems.textsecuregcm.entities.MessageProtos.ServerCertificate;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@@ -66,6 +71,8 @@ class CertificateControllerTest {
private static final ServerZkAuthOperations serverZkAuthOperations;
private static final Clock clock = Clock.fixed(Instant.now(), ZoneId.systemDefault());
private static final AccountsManager accountsManager = mock(AccountsManager.class);
static {
try {
certificateGenerator = new CertificateGenerator(Base64.getDecoder().decode(signingCertificate),
@@ -82,9 +89,14 @@ class CertificateControllerTest {
.addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class))
.setMapper(SystemMapper.jsonMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new CertificateController(certificateGenerator, serverZkAuthOperations, genericServerSecretParams, clock))
.addResource(new CertificateController(accountsManager, certificateGenerator, serverZkAuthOperations, genericServerSecretParams, clock))
.build();
@BeforeEach
void setUp() {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
}
@Test
void testValidCertificate() throws Exception {
DeliveryCertificate certificateObject = resources.getJerseyTest()

View File

@@ -38,6 +38,7 @@ import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.spam.ChallengeConstraintChecker;
import org.whispersystems.textsecuregcm.spam.ChallengeConstraintChecker.ChallengeConstraints;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider;
@@ -45,11 +46,12 @@ import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider;
@ExtendWith(DropwizardExtensionsSupport.class)
class ChallengeControllerTest {
private static final AccountsManager accountsManager = mock(AccountsManager.class);
private static final RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class);
private static final ChallengeConstraintChecker challengeConstraintChecker = mock(ChallengeConstraintChecker.class);
private static final ChallengeController challengeController =
new ChallengeController(rateLimitChallengeManager, challengeConstraintChecker);
new ChallengeController(accountsManager, rateLimitChallengeManager, challengeConstraintChecker);
private static final ResourceExtension EXTENSION = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
@@ -63,6 +65,9 @@ class ChallengeControllerTest {
@BeforeEach
void setup() {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_TWO)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_TWO));
when(challengeConstraintChecker.challengeConstraints(any(), any()))
.thenReturn(new ChallengeConstraints(true, Optional.empty()));
}

View File

@@ -27,6 +27,7 @@ import java.time.Duration;
import java.time.Instant;
import java.util.Base64;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.glassfish.jersey.server.ServerProperties;
@@ -45,6 +46,7 @@ import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.GrpcStatusRuntimeExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.devicecheck.AppleDeviceCheckManager;
import org.whispersystems.textsecuregcm.storage.devicecheck.ChallengeNotFoundException;
import org.whispersystems.textsecuregcm.storage.devicecheck.DeviceCheckKeyIdNotFoundException;
@@ -62,6 +64,7 @@ class DeviceCheckControllerTest {
private final static Duration REDEMPTION_DURATION = Duration.ofDays(5);
private final static long REDEMPTION_LEVEL = 201L;
private static final AccountsManager accountsManager = mock(AccountsManager.class);
private final static BackupAuthManager backupAuthManager = mock(BackupAuthManager.class);
private final static AppleDeviceCheckManager appleDeviceCheckManager = mock(AppleDeviceCheckManager.class);
private final static RateLimiters rateLimiters = mock(RateLimiters.class);
@@ -76,7 +79,7 @@ class DeviceCheckControllerTest {
.addProvider(new RateLimitExceededExceptionMapper())
.setMapper(SystemMapper.jsonMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new DeviceCheckController(clock, backupAuthManager, appleDeviceCheckManager, rateLimiters,
.addResource(new DeviceCheckController(clock, accountsManager, backupAuthManager, appleDeviceCheckManager, rateLimiters,
REDEMPTION_LEVEL, REDEMPTION_DURATION))
.build();
@@ -86,6 +89,8 @@ class DeviceCheckControllerTest {
reset(appleDeviceCheckManager);
reset(rateLimiters);
when(rateLimiters.forDescriptor(any())).thenReturn(mock(RateLimiter.class));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
}
@ParameterizedTest

View File

@@ -42,14 +42,12 @@ import java.util.concurrent.CompletableFuture;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.commons.lang3.RandomStringUtils;
import org.apache.commons.lang3.StringUtils;
import org.glassfish.jersey.server.ServerProperties;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource;
@@ -62,8 +60,6 @@ import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventListener;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest;
@@ -116,7 +112,6 @@ class DeviceControllerTest {
private static final Account account = mock(Account.class);
private static final Account maxedAccount = mock(Account.class);
private static final Device primaryDevice = mock(Device.class);
private static final DisconnectionRequestManager disconnectionRequestManager = mock(DisconnectionRequestManager.class);
private static final Map<String, Integer> deviceConfiguration = new HashMap<>();
private static final TestClock testClock = TestClock.now();
@@ -129,16 +124,12 @@ class DeviceControllerTest {
persistentTimer,
deviceConfiguration);
@RegisterExtension
public static final AuthHelper.AuthFilterExtension AUTH_FILTER_EXTENSION = new AuthHelper.AuthFilterExtension();
private static final ResourceExtension resources = ResourceExtension.builder()
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class))
.addProvider(new RateLimitExceededExceptionMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addProvider(new WebsocketRefreshApplicationEventListener(accountsManager, disconnectionRequestManager))
.addProvider(new DeviceLimitExceededExceptionMapper())
.addResource(deviceController)
.build();
@@ -157,8 +148,15 @@ class DeviceControllerTest {
when(account.getNumber()).thenReturn(AuthHelper.VALID_NUMBER);
when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID);
when(account.getPhoneNumberIdentifier()).thenReturn(AuthHelper.VALID_PNI);
when(account.getPrimaryDevice()).thenReturn(primaryDevice);
when(account.getDevice(anyByte())).thenReturn(Optional.empty());
when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(primaryDevice));
when(account.getDevices()).thenReturn(List.of(primaryDevice));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
when(accountsManager.getByAccountIdentifierAsync(AuthHelper.VALID_UUID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
when(accountsManager.getByE164(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(account));
when(accountsManager.getByE164(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.of(maxedAccount));
@@ -229,7 +227,7 @@ class DeviceControllerTest {
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
when(account.getDevices()).thenReturn(List.of(existingDevice));
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
@@ -310,7 +308,7 @@ class DeviceControllerTest {
final Device primaryDevice = mock(Device.class);
when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(primaryDevice));
when(account.getDevices()).thenReturn(List.of(primaryDevice));
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
@@ -362,7 +360,7 @@ class DeviceControllerTest {
final Device primaryDevice = mock(Device.class);
when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(primaryDevice));
when(account.getDevices()).thenReturn(List.of(primaryDevice));
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
@@ -398,7 +396,7 @@ class DeviceControllerTest {
void linkDeviceAtomicReusedToken() {
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
when(account.getDevices()).thenReturn(List.of(existingDevice));
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
@@ -447,7 +445,7 @@ class DeviceControllerTest {
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
when(account.getDevices()).thenReturn(List.of(existingDevice));
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
@@ -487,12 +485,12 @@ class DeviceControllerTest {
void linkDeviceAtomicConflictingChannel(final boolean fetchesMessages,
final Optional<ApnRegistrationId> apnRegistrationId,
final Optional<GcmRegistrationId> gcmRegistrationId) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
when(accountsManager.generateLinkDeviceToken(any())).thenReturn("test");
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
when(account.getDevices()).thenReturn(List.of(existingDevice));
final LinkDeviceToken deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
@@ -548,12 +546,12 @@ class DeviceControllerTest {
final KEMSignedPreKey aciPqLastResortPreKey,
final KEMSignedPreKey pniPqLastResortPreKey) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
when(accountsManager.generateLinkDeviceToken(any())).thenReturn("test");
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
when(account.getDevices()).thenReturn(List.of(existingDevice));
final LinkDeviceToken deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
@@ -613,11 +611,11 @@ class DeviceControllerTest {
aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
when(account.getDevices()).thenReturn(List.of(existingDevice));
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
@@ -647,11 +645,11 @@ class DeviceControllerTest {
final KEMSignedPreKey aciPqLastResortPreKey,
final KEMSignedPreKey pniPqLastResortPreKey) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
when(account.getDevices()).thenReturn(List.of(existingDevice));
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(aciIdentityKey);
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(pniIdentityKey);
@@ -698,7 +696,7 @@ class DeviceControllerTest {
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
when(account.getDevices()).thenReturn(List.of(existingDevice));
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
@@ -735,7 +733,7 @@ class DeviceControllerTest {
void linkDeviceRegistrationId(final int registrationId, final int pniRegistrationId, final int expectedStatusCode) {
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
when(account.getDevices()).thenReturn(List.of(existingDevice));
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
@@ -800,17 +798,16 @@ class DeviceControllerTest {
@Test
void maxDevicesTest() {
final AuthHelper.TestAccount testAccount = AUTH_FILTER_EXTENSION.createTestAccount();
final List<Device> devices = IntStream.range(0, DeviceController.MAX_DEVICES + 1)
.mapToObj(i -> mock(Device.class))
.toList();
when(testAccount.account.getDevices()).thenReturn(devices);
when(account.getDevices()).thenReturn(devices);
Response response = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
.header("Authorization", testAccount.getAuthHeader())
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get();
assertEquals(411, response.getStatus());
@@ -829,7 +826,7 @@ class DeviceControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
assertThat(response.hasEntity()).isFalse();
verify(AuthHelper.VALID_DEVICE).setCapabilities(Set.of(DeviceCapability.DELETE_SYNC));
verify(primaryDevice).setCapabilities(Set.of(DeviceCapability.DELETE_SYNC));
}
}
@@ -851,12 +848,12 @@ class DeviceControllerTest {
void removeDevice() {
// this is a static mock, so it might have previous invocations
clearInvocations(AuthHelper.VALID_ACCOUNT);
clearInvocations(account);
final byte deviceId = 2;
when(accountsManager.removeDevice(AuthHelper.VALID_ACCOUNT, deviceId))
.thenReturn(CompletableFuture.completedFuture(AuthHelper.VALID_ACCOUNT));
when(accountsManager.removeDevice(account, deviceId))
.thenReturn(CompletableFuture.completedFuture(account));
try (final Response response = resources
.getJerseyTest()
@@ -869,14 +866,14 @@ class DeviceControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
assertThat(response.hasEntity()).isFalse();
verify(accountsManager).removeDevice(AuthHelper.VALID_ACCOUNT, deviceId);
verify(accountsManager).removeDevice(account, deviceId);
}
}
@Test
void unlinkPrimaryDevice() {
// this is a static mock, so it might have previous invocations
clearInvocations(AuthHelper.VALID_ACCOUNT);
clearInvocations(account);
try (final Response response = resources
.getJerseyTest()
@@ -897,7 +894,10 @@ class DeviceControllerTest {
final byte deviceId = 2;
when(accountsManager.removeDevice(AuthHelper.VALID_ACCOUNT_3, deviceId))
.thenReturn(CompletableFuture.completedFuture(AuthHelper.VALID_ACCOUNT));
.thenReturn(CompletableFuture.completedFuture(account));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_3))
.thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_3));
try (final Response response = resources
.getJerseyTest()
@@ -946,7 +946,7 @@ class DeviceControllerTest {
assertEquals(204, response.getStatus());
}
verify(clientPublicKeysManager).setPublicKey(AuthHelper.VALID_ACCOUNT, AuthHelper.VALID_DEVICE.getId(), request.publicKey());
verify(clientPublicKeysManager).setPublicKey(account, AuthHelper.VALID_DEVICE.getId(), request.publicKey());
}
@Test
@@ -959,7 +959,7 @@ class DeviceControllerTest {
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);
when(accountsManager
.waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any()))
.waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(primaryDevice), eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(deviceInfo)));
when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null));
@@ -985,7 +985,7 @@ class DeviceControllerTest {
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);
when(accountsManager
.waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any()))
.waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(primaryDevice), eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null));
@@ -1005,7 +1005,7 @@ class DeviceControllerTest {
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);
when(accountsManager
.waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any()))
.waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(primaryDevice), eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.failedFuture(new IllegalArgumentException()));
when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null));
@@ -1079,7 +1079,7 @@ class DeviceControllerTest {
new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8)));
when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null));
when(accountsManager.recordTransferArchiveUpload(AuthHelper.VALID_ACCOUNT, deviceId, deviceCreated, transferArchive))
when(accountsManager.recordTransferArchiveUpload(account, deviceId, deviceCreated, transferArchive))
.thenReturn(CompletableFuture.completedFuture(null));
try (final Response response = resources.getJerseyTest()
@@ -1092,7 +1092,7 @@ class DeviceControllerTest {
assertEquals(204, response.getStatus());
verify(accountsManager)
.recordTransferArchiveUpload(AuthHelper.VALID_ACCOUNT, deviceId, deviceCreated, transferArchive);
.recordTransferArchiveUpload(account, deviceId, deviceCreated, transferArchive);
}
}
@@ -1103,7 +1103,7 @@ class DeviceControllerTest {
final RemoteAttachmentError transferFailure = new RemoteAttachmentError(RemoteAttachmentError.ErrorType.CONTINUE_WITHOUT_UPLOAD);
when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null));
when(accountsManager.recordTransferArchiveUpload(AuthHelper.VALID_ACCOUNT, deviceId, deviceCreated, transferFailure))
when(accountsManager.recordTransferArchiveUpload(account, deviceId, deviceCreated, transferFailure))
.thenReturn(CompletableFuture.completedFuture(null));
try (final Response response = resources.getJerseyTest()
@@ -1116,7 +1116,7 @@ class DeviceControllerTest {
assertEquals(204, response.getStatus());
verify(accountsManager)
.recordTransferArchiveUpload(AuthHelper.VALID_ACCOUNT, deviceId, deviceCreated, transferFailure);
.recordTransferArchiveUpload(account, deviceId, deviceCreated, transferFailure);
}
}
@@ -1186,7 +1186,7 @@ class DeviceControllerTest {
new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8)));
when(rateLimiter.validateAsync(anyString())).thenReturn(CompletableFuture.completedFuture(null));
when(accountsManager.waitForTransferArchive(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_DEVICE), any()))
when(accountsManager.waitForTransferArchive(eq(account), eq(primaryDevice), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(transferArchive)));
try (final Response response = resources.getJerseyTest()
@@ -1206,7 +1206,7 @@ class DeviceControllerTest {
new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8)));
when(rateLimiter.validateAsync(anyString())).thenReturn(CompletableFuture.completedFuture(null));
when(accountsManager.waitForTransferArchive(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_DEVICE), any()))
when(accountsManager.waitForTransferArchive(eq(account), eq(primaryDevice), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(transferArchive)));
try (final Response response = resources.getJerseyTest()
@@ -1223,7 +1223,7 @@ class DeviceControllerTest {
@Test
void waitForTransferArchiveNoArchiveUploaded() {
when(rateLimiter.validateAsync(anyString())).thenReturn(CompletableFuture.completedFuture(null));
when(accountsManager.waitForTransferArchive(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_DEVICE), any()))
when(accountsManager.waitForTransferArchive(eq(account), eq(primaryDevice), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
try (final Response response = resources.getJerseyTest()

View File

@@ -19,6 +19,7 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator;
import org.whispersystems.textsecuregcm.configuration.DirectoryV2ClientConfiguration;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
@@ -35,7 +36,7 @@ class DirectoryControllerV2Test {
final Account account = mock(Account.class);
final UUID uuid = UUID.fromString("11111111-1111-1111-1111-111111111111");
when(account.getUuid()).thenReturn(uuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(uuid);
final ExternalServiceCredentials credentials = controller.getAuthToken(
new AuthenticatedDevice(account, mock(Device.class)));

View File

@@ -140,6 +140,8 @@ class DonationControllerTest {
when(receiptCredentialPresentation.getReceiptExpirationTime()).thenReturn(receiptExpiration);
when(redeemedReceiptsManager.put(same(receiptSerial), eq(receiptExpiration), eq(receiptLevel), eq(AuthHelper.VALID_UUID))).thenReturn(
CompletableFuture.completedFuture(Boolean.FALSE));
when(accountsManager.getByAccountIdentifierAsync(eq(AuthHelper.VALID_UUID))).thenReturn(
CompletableFuture.completedFuture(Optional.of(AuthHelper.VALID_ACCOUNT)));
RedeemReceiptRequest request = new RedeemReceiptRequest(presentation, true, true);
Response response = resources.getJerseyTest()

View File

@@ -242,10 +242,21 @@ class KeysControllerTest {
when(existsAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of("1337".getBytes()));
when(accounts.getByServiceIdentifier(any())).thenReturn(Optional.empty());
when(accounts.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(accounts.getByServiceIdentifier(new AciServiceIdentifier(EXISTS_UUID))).thenReturn(Optional.of(existsAccount));
when(accounts.getByServiceIdentifier(new PniServiceIdentifier(EXISTS_PNI))).thenReturn(Optional.of(existsAccount));
when(accounts.getByServiceIdentifierAsync(new AciServiceIdentifier(EXISTS_UUID)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(existsAccount)));
when(accounts.getByServiceIdentifierAsync(new PniServiceIdentifier(EXISTS_PNI)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(existsAccount)));
when(accounts.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accounts.getByAccountIdentifierAsync(AuthHelper.VALID_UUID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(AuthHelper.VALID_ACCOUNT)));
when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter);
when(KEYS.storeEcOneTimePreKeys(any(), anyByte(), any()))

View File

@@ -236,6 +236,14 @@ class MessageControllerTest {
when(accountsManager.getByServiceIdentifierAsync(MULTI_DEVICE_PNI_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(multiDeviceAccount)));
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(INTERNATIONAL_UUID))).thenReturn(CompletableFuture.completedFuture(Optional.of(internationalAccount)));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accountsManager.getByAccountIdentifierAsync(AuthHelper.VALID_UUID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(AuthHelper.VALID_ACCOUNT)));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_3)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_3));
when(accountsManager.getByAccountIdentifierAsync(AuthHelper.VALID_UUID_3))
.thenReturn(CompletableFuture.completedFuture(Optional.of(AuthHelper.VALID_ACCOUNT_3)));
when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getStoriesLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getInboundMessageBytes()).thenReturn(rateLimiter);

View File

@@ -221,6 +221,8 @@ class ProfileControllerTest {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(capabilitiesAccount));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(AuthHelper.VALID_UUID))).thenReturn(Optional.of(capabilitiesAccount));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_TWO)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_TWO));
final byte[] name = TestRandomUtil.nextBytes(81);
final byte[] emoji = TestRandomUtil.nextBytes(60);
final byte[] about = TestRandomUtil.nextBytes(156);
@@ -1155,6 +1157,7 @@ class ProfileControllerTest {
reset(accountsManager);
final int accountsManagerUpdateRetryCount = 2;
AccountsHelper.setupMockUpdateWithRetries(accountsManager, accountsManagerUpdateRetryCount);
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_TWO)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_TWO));
// set up two invocations -- one for each AccountsManager#update try
when(AuthHelper.VALID_ACCOUNT_TWO.getBadges())
.thenReturn(List.of(

View File

@@ -170,7 +170,12 @@ class RemoteConfigControllerTest {
void testHashKeyLinkedConfigs() {
boolean allUnlinkedConfigsMatched = true;
for (AuthHelper.TestAccount testAccount : AuthHelper.TEST_ACCOUNTS) {
UserRemoteConfigList configuration = resources.getJerseyTest().target("/v1/config/").request().header("Authorization", testAccount.getAuthHeader()).get(UserRemoteConfigList.class);
UserRemoteConfigList configuration = resources.getJerseyTest()
.target("/v1/config/")
.request()
.header("Authorization", testAccount.getAuthHeader())
.get(UserRemoteConfigList.class);
assertThat(configuration.getConfig()).hasSize(11);
final UserRemoteConfig linkedConfig0 = configuration.getConfig().get(7);

View File

@@ -7,7 +7,6 @@ package org.whispersystems.textsecuregcm.entities;
import static org.junit.jupiter.api.Assertions.assertEquals;
import java.util.Random;
import java.util.UUID;
import javax.annotation.Nullable;
import org.junit.jupiter.api.Test;
@@ -16,7 +15,6 @@ import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
@@ -65,15 +63,13 @@ class OutgoingMessageEntityTest {
@Test
void entityPreservesEnvelope() {
final byte[] reportSpamToken = TestRandomUtil.nextBytes(8);
final AciServiceIdentifier sourceServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
final Account account = new Account();
account.setUuid(UUID.randomUUID());
IncomingMessage message = new IncomingMessage(1, (byte) 44, 55, TestRandomUtil.nextBytes(4));
final IncomingMessage message = new IncomingMessage(1, (byte) 44, 55, TestRandomUtil.nextBytes(4));
MessageProtos.Envelope baseEnvelope = message.toEnvelope(
new AciServiceIdentifier(UUID.randomUUID()),
account,
sourceServiceIdentifier,
(byte) 123,
System.currentTimeMillis(),
false,

View File

@@ -153,7 +153,7 @@ class MetricsRequestEventListenerTest {
final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
final WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, TestPrincipal.reusableAuth("foo"),
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, TestPrincipal.authenticatedTestPrincipal("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
final Session session = mock(Session.class);
@@ -220,7 +220,7 @@ class MetricsRequestEventListenerTest {
final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
final WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, TestPrincipal.reusableAuth("foo"),
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, TestPrincipal.authenticatedTestPrincipal("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
final Session session = mock(Session.class);

View File

@@ -220,6 +220,8 @@ public class AccountsHelper {
case "getBackupVoucher" -> when(updatedAccount.getBackupVoucher()).thenAnswer(stubbing);
case "getLastSeen" -> when(updatedAccount.getLastSeen()).thenAnswer(stubbing);
case "hasLockedCredentials" -> when(updatedAccount.hasLockedCredentials()).thenAnswer(stubbing);
case "getCurrentProfileVersion" -> when(updatedAccount.getCurrentProfileVersion()).thenAnswer(stubbing);
case "getUnidentifiedAccessKey" -> when(updatedAccount.getUnidentifiedAccessKey()).thenAnswer(stubbing);
default -> throw new IllegalArgumentException("unsupported method: Account#" + stubbing.getInvocation().getMethod().getName());
}
}

View File

@@ -277,6 +277,7 @@ public class AuthHelper {
when(account.getPrimaryDevice()).thenReturn(device);
when(account.getNumber()).thenReturn(number);
when(account.getUuid()).thenReturn(uuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(uuid);
when(accountsManager.getByE164(number)).thenReturn(Optional.of(account));
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account));
}

View File

@@ -5,8 +5,7 @@
package org.whispersystems.textsecuregcm.tests.util;
import java.security.Principal;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import java.util.Optional;
public class TestPrincipal implements Principal {
@@ -21,7 +20,7 @@ public class TestPrincipal implements Principal {
return name;
}
public static ReusableAuth<TestPrincipal> reusableAuth(final String name) {
return ReusableAuth.authenticated(new TestPrincipal(name), PrincipalSupplier.forImmutablePrincipal());
public static Optional<TestPrincipal> authenticatedTestPrincipal(final String name) {
return Optional.of(new TestPrincipal(name));
}
}

View File

@@ -176,7 +176,7 @@ class LoggingUnhandledExceptionMapperTest {
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog,
TestPrincipal.reusableAuth("foo"),
TestPrincipal.authenticatedTestPrincipal("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);

View File

@@ -28,7 +28,6 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.websocket.auth.InvalidCredentialsException;
import org.whispersystems.websocket.auth.PrincipalSupplier;
class WebSocketAccountAuthenticatorTest {
@@ -70,14 +69,12 @@ class WebSocketAccountAuthenticatorTest {
when(upgradeRequest.getHeader(eq(HttpHeaders.AUTHORIZATION))).thenReturn(authorizationHeaderValue);
}
final WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(
accountAuthenticator,
mock(PrincipalSupplier.class));
final WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator);
if (expectInvalid) {
assertThrows(InvalidCredentialsException.class, () -> webSocketAuthenticator.authenticate(upgradeRequest));
} else {
assertEquals(expectAccount, webSocketAuthenticator.authenticate(upgradeRequest).ref().isPresent());
assertEquals(expectAccount, webSocketAuthenticator.authenticate(upgradeRequest).isPresent());
}
}

View File

@@ -43,11 +43,11 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
@@ -111,7 +111,7 @@ class WebSocketConnectionIntegrationTest {
clientReleaseManager = mock(ClientReleaseManager.class);
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(account.getIdentifier(IdentityType.ACI)).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(Device.PRIMARY_ID);
}
@@ -137,7 +137,8 @@ class WebSocketConnectionIntegrationTest {
new MessageMetrics(),
mock(PushNotificationManager.class),
mock(PushNotificationScheduler.class),
new AuthenticatedDevice(account, device),
account,
device,
webSocketClient,
scheduledExecutorService,
messageDeliveryScheduler,
@@ -159,14 +160,14 @@ class WebSocketConnectionIntegrationTest {
expectedMessages.add(envelope);
}
messagesDynamoDb.store(persistedMessages, account.getUuid(), device);
messagesDynamoDb.store(persistedMessages, account.getIdentifier(IdentityType.ACI), device);
}
for (int i = 0; i < cachedMessageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join();
messagesCache.insert(messageGuid, account.getIdentifier(IdentityType.ACI), device.getId(), envelope).join();
expectedMessages.add(envelope);
}
@@ -226,7 +227,8 @@ class WebSocketConnectionIntegrationTest {
new MessageMetrics(),
mock(PushNotificationManager.class),
mock(PushNotificationScheduler.class),
new AuthenticatedDevice(account, device),
account,
device,
webSocketClient,
scheduledExecutorService,
messageDeliveryScheduler,
@@ -250,13 +252,13 @@ class WebSocketConnectionIntegrationTest {
expectedMessages.add(envelope);
}
messagesDynamoDb.store(persistedMessages, account.getUuid(), device);
messagesDynamoDb.store(persistedMessages, account.getIdentifier(IdentityType.ACI), device);
}
for (int i = 0; i < cachedMessageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join();
messagesCache.insert(messageGuid, account.getIdentifier(IdentityType.ACI), device.getId(), envelope).join();
expectedMessages.add(envelope);
}
@@ -296,7 +298,8 @@ class WebSocketConnectionIntegrationTest {
new MessageMetrics(),
mock(PushNotificationManager.class),
mock(PushNotificationScheduler.class),
new AuthenticatedDevice(account, device),
account,
device,
webSocketClient,
100, // use a very short timeout, so that this test completes quickly
scheduledExecutorService,
@@ -321,13 +324,13 @@ class WebSocketConnectionIntegrationTest {
expectedMessages.add(envelope);
}
messagesDynamoDb.store(persistedMessages, account.getUuid(), device);
messagesDynamoDb.store(persistedMessages, account.getIdentifier(IdentityType.ACI), device);
}
for (int i = 0; i < cachedMessageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join();
messagesCache.insert(messageGuid, account.getIdentifier(IdentityType.ACI), device.getId(), envelope).join();
expectedMessages.add(envelope);
}

View File

@@ -56,6 +56,7 @@ import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
@@ -67,9 +68,7 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import reactor.core.publisher.Flux;
@@ -90,7 +89,6 @@ class WebSocketConnectionTest {
private AccountsManager accountsManager;
private Account account;
private Device device;
private AuthenticatedDevice auth;
private UpgradeRequest upgradeRequest;
private MessagesManager messagesManager;
private ReceiptSender receiptSender;
@@ -104,7 +102,6 @@ class WebSocketConnectionTest {
accountsManager = mock(AccountsManager.class);
account = mock(Account.class);
device = mock(Device.class);
auth = new AuthenticatedDevice(account, device);
upgradeRequest = mock(UpgradeRequest.class);
messagesManager = mock(MessagesManager.class);
receiptSender = mock(ReceiptSender.class);
@@ -122,8 +119,8 @@ class WebSocketConnectionTest {
@Test
void testCredentials() throws Exception {
WebSocketAccountAuthenticator webSocketAuthenticator =
new WebSocketAccountAuthenticator(accountAuthenticator, mock(PrincipalSupplier.class));
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, messagesManager,
new WebSocketAccountAuthenticator(accountAuthenticator);
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(accountsManager, receiptSender, messagesManager,
new MessageMetrics(), mock(PushNotificationManager.class), mock(PushNotificationScheduler.class),
mock(WebSocketConnectionEventManager.class), retrySchedulingExecutor,
messageDeliveryScheduler, clientReleaseManager, mock(MessageDeliveryLoopMonitor.class),
@@ -133,9 +130,9 @@ class WebSocketConnectionTest {
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
.thenReturn(Optional.of(new AuthenticatedDevice(account, device)));
ReusableAuth<AuthenticatedDevice> account = webSocketAuthenticator.authenticate(upgradeRequest);
when(sessionContext.getAuthenticated()).thenReturn(account.ref().orElse(null));
when(sessionContext.getAuthenticated(AuthenticatedDevice.class)).thenReturn(account.ref().orElse(null));
Optional<AuthenticatedDevice> account = webSocketAuthenticator.authenticate(upgradeRequest);
when(sessionContext.getAuthenticated()).thenReturn(account.orElse(null));
when(sessionContext.getAuthenticated(AuthenticatedDevice.class)).thenReturn(account.orElse(null));
final WebSocketClient webSocketClient = mock(WebSocketClient.class);
when(webSocketClient.getUserAgent()).thenReturn("Signal-Android/6.22.8");
@@ -150,7 +147,7 @@ class WebSocketConnectionTest {
// unauthenticated
when(upgradeRequest.getParameterMap()).thenReturn(Map.of());
account = webSocketAuthenticator.authenticate(upgradeRequest);
assertFalse(account.ref().isPresent());
assertFalse(account.isPresent());
connectListener.onWebSocketConnect(sessionContext);
verify(sessionContext, times(2)).addWebsocketClosedListener(
@@ -174,7 +171,7 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn(deviceId);
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
final Device sender1device = mock(Device.class);
@@ -191,7 +188,7 @@ class WebSocketConnectionTest {
String userAgent = HttpHeaders.USER_AGENT;
when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false))
when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false))
.thenReturn(Flux.fromIterable(outgoingMessages));
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
@@ -237,7 +234,7 @@ class WebSocketConnectionTest {
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@@ -320,7 +317,7 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn(deviceId);
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
final Device sender1device = mock(Device.class);
@@ -337,7 +334,7 @@ class WebSocketConnectionTest {
String userAgent = HttpHeaders.USER_AGENT;
when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false))
when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false))
.thenReturn(Flux.fromIterable(pendingMessages));
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
@@ -364,7 +361,7 @@ class WebSocketConnectionTest {
futures.get(1).complete(response);
futures.get(0).completeExceptionally(new IOException());
verify(receiptSender, times(1)).sendReceipt(eq(new AciServiceIdentifier(account.getUuid())), eq(deviceId), eq(new AciServiceIdentifier(senderTwoUuid)),
verify(receiptSender, times(1)).sendReceipt(eq(new AciServiceIdentifier(account.getIdentifier(IdentityType.ACI))), eq(deviceId), eq(new AciServiceIdentifier(senderTwoUuid)),
eq(secondMessage.getClientTimestamp()));
connection.stop();
@@ -377,7 +374,7 @@ class WebSocketConnectionTest {
final WebSocketConnection connection = webSocketConnection(client);
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(account.getIdentifier(IdentityType.ACI)).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@@ -385,7 +382,7 @@ class WebSocketConnectionTest {
final AtomicBoolean returnMessageList = new AtomicBoolean(false);
when(
messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false))
messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false))
.thenAnswer(invocation -> {
synchronized (threadWaiting) {
threadWaiting.set(true);
@@ -442,7 +439,7 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+18005551234");
final UUID accountUuid = UUID.randomUUID();
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@@ -490,7 +487,7 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+18005551234");
final UUID accountUuid = UUID.randomUUID();
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@@ -573,7 +570,7 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+18005551234");
final UUID accountUuid = UUID.randomUUID();
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@@ -581,7 +578,7 @@ class WebSocketConnectionTest {
final List<Envelope> messages = List.of(
createMessage(senderUuid, UUID.randomUUID(), 1111L, "message the first"));
when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false))
when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false))
.thenReturn(Flux.fromIterable(messages))
.thenReturn(Flux.empty());
@@ -629,7 +626,7 @@ class WebSocketConnectionTest {
private WebSocketConnection webSocketConnection(final WebSocketClient client) {
return new WebSocketConnection(receiptSender, messagesManager, new MessageMetrics(),
mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), auth, client,
mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), account, device, client,
retrySchedulingExecutor, Schedulers.immediate(), clientReleaseManager,
mock(MessageDeliveryLoopMonitor.class), mock(ExperimentEnrollmentManager.class));
}
@@ -642,7 +639,7 @@ class WebSocketConnectionTest {
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@@ -669,7 +666,7 @@ class WebSocketConnectionTest {
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@@ -725,7 +722,7 @@ class WebSocketConnectionTest {
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@@ -741,11 +738,11 @@ class WebSocketConnectionTest {
// anything.
connection.processStoredMessages();
verify(messagesManager).getMessagesForDeviceReactive(account.getUuid(), device, false);
verify(messagesManager).getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false);
connection.handleNewMessageAvailable();
verify(messagesManager).getMessagesForDeviceReactive(account.getUuid(), device, true);
verify(messagesManager).getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, true);
}
@Test
@@ -756,7 +753,7 @@ class WebSocketConnectionTest {
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@@ -773,7 +770,7 @@ class WebSocketConnectionTest {
connection.processStoredMessages();
connection.handleMessagesPersisted();
verify(messagesManager, times(2)).getMessagesForDeviceReactive(account.getUuid(), device, false);
verify(messagesManager, times(2)).getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false);
}
@Test
@@ -783,9 +780,9 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn((byte) 2);
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false))
when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false))
.thenReturn(Flux.error(new RedisException("OH NO")));
when(retrySchedulingExecutor.schedule(any(Runnable.class), anyLong(), any())).thenAnswer(
@@ -812,9 +809,9 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn((byte) 2);
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false))
when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false))
.thenReturn(Flux.error(new RedisException("OH NO")));
final WebSocketClient client = mock(WebSocketClient.class);
@@ -835,7 +832,7 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn(deviceId);
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
final int totalMessages = 1000;
@@ -884,7 +881,7 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn(deviceId);
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
final AtomicBoolean canceled = new AtomicBoolean();