Lifecycle management for Account objects reused accross websocket requests

This commit is contained in:
Ravi Khadiwala
2024-02-06 16:59:42 -06:00
committed by ravi-signal
parent 29ef3f0b41
commit 26ffa19f36
38 changed files with 1317 additions and 457 deletions

View File

@@ -188,6 +188,7 @@ import org.whispersystems.textsecuregcm.spam.SenderOverrideProvider;
import org.whispersystems.textsecuregcm.spam.SpamChecker;
import org.whispersystems.textsecuregcm.spam.SpamFilter;
import org.whispersystems.textsecuregcm.storage.AccountLockManager;
import org.whispersystems.textsecuregcm.storage.AccountPrincipalSupplier;
import org.whispersystems.textsecuregcm.storage.Accounts;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ChangeNumberManager;
@@ -812,7 +813,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
WebSocketEnvironment<AuthenticatedAccount> webSocketEnvironment = new WebSocketEnvironment<>(environment,
config.getWebSocketConfiguration(), Duration.ofMillis(90000));
webSocketEnvironment.jersey().register(new VirtualExecutorServiceProvider("managed-async-websocket-virtual-thread-"));
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator));
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator, new AccountPrincipalSupplier(accountsManager)));
webSocketEnvironment.setConnectListener(
new AuthenticatedConnectListener(receiptSender, messagesManager, pushNotificationManager,
clientPresenceManager, websocketScheduledExecutor, messageDeliveryScheduler, clientReleaseManager));

View File

@@ -21,7 +21,6 @@ import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.RefreshingAccountAndDeviceSupplier;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util;
@@ -108,8 +107,7 @@ public class AccountAuthenticator implements Authenticator<BasicCredentials, Aut
device.get(),
SaltedTokenHash.generateFor(basicCredentials.getPassword())); // new credentials have current version
}
return Optional.of(new AuthenticatedAccount(
new RefreshingAccountAndDeviceSupplier(authenticatedAccount, device.get().getId(), accountsManager)));
return Optional.of(new AuthenticatedAccount(authenticatedAccount, device.get()));
}
return Optional.empty();

View File

@@ -5,7 +5,6 @@
package org.whispersystems.textsecuregcm.auth;
import com.google.common.annotations.VisibleForTesting;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
@@ -45,10 +44,6 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
this.accountsManager = accountsManager;
}
@VisibleForTesting
static Map<Byte, Boolean> buildDevicesEnabledMap(final Account account) {
return account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::isEnabled));
}
@Override
public void handleRequestFiltered(final RequestEvent requestEvent) {
@@ -60,10 +55,13 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
setAccount(requestEvent.getContainerRequest(), account));
}
}
public static void setAccount(final ContainerRequest containerRequest, final Account account) {
containerRequest.setProperty(ACCOUNT_UUID, account.getUuid());
containerRequest.setProperty(DEVICES_ENABLED, buildDevicesEnabledMap(account));
setAccount(containerRequest, ContainerRequestUtil.AccountInfo.fromAccount(account));
}
private static void setAccount(final ContainerRequest containerRequest, final ContainerRequestUtil.AccountInfo info) {
containerRequest.setProperty(ACCOUNT_UUID, info.accountId());
containerRequest.setProperty(DEVICES_ENABLED, info.devicesEnabled());
}
@Override
@@ -75,25 +73,28 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
@SuppressWarnings("unchecked") final Map<Byte, Boolean> initialDevicesEnabled =
(Map<Byte, Boolean>) requestEvent.getContainerRequest().getProperty(DEVICES_ENABLED);
return accountsManager.getByAccountIdentifier((UUID) requestEvent.getContainerRequest().getProperty(ACCOUNT_UUID)).map(account -> {
final Set<Byte> deviceIdsToDisplace;
final Map<Byte, Boolean> currentDevicesEnabled = buildDevicesEnabledMap(account);
return accountsManager.getByAccountIdentifier((UUID) requestEvent.getContainerRequest().getProperty(ACCOUNT_UUID))
.map(ContainerRequestUtil.AccountInfo::fromAccount)
.map(account -> {
final Set<Byte> deviceIdsToDisplace;
final Map<Byte, Boolean> currentDevicesEnabled = account.devicesEnabled();
if (!initialDevicesEnabled.equals(currentDevicesEnabled)) {
deviceIdsToDisplace = new HashSet<>(initialDevicesEnabled.keySet());
deviceIdsToDisplace.addAll(currentDevicesEnabled.keySet());
} else {
deviceIdsToDisplace = Collections.emptySet();
}
if (!initialDevicesEnabled.equals(currentDevicesEnabled)) {
deviceIdsToDisplace = new HashSet<>(initialDevicesEnabled.keySet());
deviceIdsToDisplace.addAll(currentDevicesEnabled.keySet());
} else {
deviceIdsToDisplace = Collections.emptySet();
}
return deviceIdsToDisplace.stream()
.map(deviceId -> new Pair<>(account.getUuid(), deviceId))
.collect(Collectors.toList());
}).orElseGet(() -> {
logger.error("Request had account, but it is no longer present");
return Collections.emptyList();
});
} else
return deviceIdsToDisplace.stream()
.map(deviceId -> new Pair<>(account.accountId(), deviceId))
.collect(Collectors.toList());
}).orElseGet(() -> {
logger.error("Request had account, but it is no longer present");
return Collections.emptyList();
});
} else {
return Collections.emptyList();
}
}
}

View File

@@ -10,24 +10,24 @@ import java.util.function.Supplier;
import javax.security.auth.Subject;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Pair;
public class AuthenticatedAccount implements Principal, AccountAndAuthenticatedDeviceHolder {
private final Account account;
private final Device device;
private final Supplier<Pair<Account, Device>> accountAndDevice;
public AuthenticatedAccount(final Supplier<Pair<Account, Device>> accountAndDevice) {
this.accountAndDevice = accountAndDevice;
public AuthenticatedAccount(final Account account, final Device device) {
this.account = account;
this.device = device;
}
@Override
public Account getAccount() {
return accountAndDevice.get().first();
return account;
}
@Override
public Device getAuthenticatedDevice() {
return accountAndDevice.get().second();
return device;
}
// Principal implementation

View File

@@ -0,0 +1,20 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* Indicates that an endpoint changes the phone number and PNI keys associated with an account, and that
* any websockets associated with the account may need to be refreshed after a call to that endpoint.
*/
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface ChangesPhoneNumber {
}

View File

@@ -7,15 +7,42 @@ package org.whispersystems.textsecuregcm.auth;
import org.glassfish.jersey.server.ContainerRequest;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import javax.ws.rs.core.SecurityContext;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.Collectors;
class ContainerRequestUtil {
static Optional<Account> getAuthenticatedAccount(final ContainerRequest request) {
private static Map<Byte, Boolean> buildDevicesEnabledMap(final Account account) {
return account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::isEnabled));
}
/**
* A read-only subset of the authenticated Account object, to enforce that filter-based consumers do not perform
* account modifying operations.
*/
record AccountInfo(UUID accountId, String e164, Map<Byte, Boolean> devicesEnabled) {
static AccountInfo fromAccount(final Account account) {
return new AccountInfo(
account.getUuid(),
account.getNumber(),
buildDevicesEnabledMap(account));
}
}
static Optional<AccountInfo> getAuthenticatedAccount(final ContainerRequest request) {
return Optional.ofNullable(request.getSecurityContext())
.map(SecurityContext::getUserPrincipal)
.map(principal -> principal instanceof AccountAndAuthenticatedDeviceHolder
? ((AccountAndAuthenticatedDeviceHolder) principal).getAccount() : null);
.map(principal -> {
if (principal instanceof AccountAndAuthenticatedDeviceHolder aaadh) {
return aaadh.getAccount();
}
return null;
})
.map(AccountInfo::fromAccount);
}
}

View File

@@ -7,40 +7,50 @@ package org.whispersystems.textsecuregcm.auth;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.Collectors;
import org.glassfish.jersey.server.monitoring.RequestEvent;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.util.Pair;
public class PhoneNumberChangeRefreshRequirementProvider implements WebsocketRefreshRequirementProvider {
private static final String ACCOUNT_UUID =
PhoneNumberChangeRefreshRequirementProvider.class.getName() + ".accountUuid";
private static final String INITIAL_NUMBER_KEY =
PhoneNumberChangeRefreshRequirementProvider.class.getName() + ".initialNumber";
private final AccountsManager accountsManager;
public PhoneNumberChangeRefreshRequirementProvider(final AccountsManager accountsManager) {
this.accountsManager = accountsManager;
}
@Override
public void handleRequestFiltered(final RequestEvent requestEvent) {
if (requestEvent.getUriInfo().getMatchedResourceMethod().getInvocable().getHandlingMethod()
.getAnnotation(ChangesPhoneNumber.class) == null) {
return;
}
ContainerRequestUtil.getAuthenticatedAccount(requestEvent.getContainerRequest())
.ifPresent(account -> requestEvent.getContainerRequest().setProperty(INITIAL_NUMBER_KEY, account.getNumber()));
.ifPresent(account -> {
requestEvent.getContainerRequest().setProperty(INITIAL_NUMBER_KEY, account.e164());
requestEvent.getContainerRequest().setProperty(ACCOUNT_UUID, account.accountId());
});
}
@Override
public List<Pair<UUID, Byte>> handleRequestFinished(final RequestEvent requestEvent) {
final String initialNumber = (String) requestEvent.getContainerRequest().getProperty(INITIAL_NUMBER_KEY);
if (initialNumber != null) {
final Optional<Account> maybeAuthenticatedAccount =
ContainerRequestUtil.getAuthenticatedAccount(requestEvent.getContainerRequest());
return maybeAuthenticatedAccount
.filter(account -> !initialNumber.equals(account.getNumber()))
.map(account -> account.getDevices().stream()
.map(device -> new Pair<>(account.getUuid(), device.getId()))
.collect(Collectors.toList()))
.orElse(Collections.emptyList());
} else {
if (initialNumber == null) {
return Collections.emptyList();
}
return accountsManager.getByAccountIdentifier((UUID) requestEvent.getContainerRequest().getProperty(ACCOUNT_UUID))
.filter(account -> !initialNumber.equals(account.getNumber()))
.map(account -> account.getDevices().stream()
.map(device -> new Pair<>(account.getUuid(), device.getId()))
.collect(Collectors.toList()))
.orElse(Collections.emptyList());
}
}

View File

@@ -24,7 +24,7 @@ public class WebsocketRefreshApplicationEventListener implements ApplicationEven
this.websocketRefreshRequestEventListener = new WebsocketRefreshRequestEventListener(clientPresenceManager,
new AuthEnablementRefreshRequirementProvider(accountsManager),
new PhoneNumberChangeRefreshRequirementProvider());
new PhoneNumberChangeRefreshRequirementProvider(accountsManager));
}
@Override

View File

@@ -36,6 +36,7 @@ import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.ChangesPhoneNumber;
import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager;
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
import org.whispersystems.textsecuregcm.entities.AccountDataReportResponse;
@@ -90,6 +91,7 @@ public class AccountControllerV2 {
@Path("/number")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
@ChangesPhoneNumber
@Operation(summary = "Change number", description = "Changes a phone number for an existing account.")
@ApiResponse(responseCode = "200", description = "The phone number associated with the authenticated account was changed successfully", useReturnTypeSchema = true)
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")

View File

@@ -0,0 +1,36 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.websocket.auth.PrincipalSupplier;
public class AccountPrincipalSupplier implements PrincipalSupplier<AuthenticatedAccount> {
private final AccountsManager accountsManager;
public AccountPrincipalSupplier(final AccountsManager accountsManager) {
this.accountsManager = accountsManager;
}
@Override
public AuthenticatedAccount refresh(final AuthenticatedAccount oldAccount) {
final Account account = accountsManager.getByAccountIdentifier(oldAccount.getAccount().getUuid())
.orElseThrow(() -> new RefreshingAccountNotFoundException("Could not find account"));
final Device device = account.getDevice(oldAccount.getAuthenticatedDevice().getId())
.orElseThrow(() -> new RefreshingAccountNotFoundException("Could not find device"));
return new AuthenticatedAccount(account, device);
}
@Override
public AuthenticatedAccount deepCopy(final AuthenticatedAccount authenticatedAccount) {
final Account cloned = AccountUtil.cloneAccountAsNotStale(authenticatedAccount.getAccount());
return new AuthenticatedAccount(
cloned,
cloned.getDevice(authenticatedAccount.getAuthenticatedDevice().getId())
.orElseThrow(() -> new IllegalStateException(
"Could not find device from a clone of an account where the device was present")));
}
}

View File

@@ -5,10 +5,12 @@
package org.whispersystems.textsecuregcm.storage;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import java.io.IOException;
class AccountUtil {
public class AccountUtil {
static Account cloneAccountAsNotStale(final Account account) {
try {

View File

@@ -1,35 +0,0 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import java.util.function.Supplier;
import org.whispersystems.textsecuregcm.util.Pair;
public class RefreshingAccountAndDeviceSupplier implements Supplier<Pair<Account, Device>> {
private Account account;
private Device device;
private final AccountsManager accountsManager;
public RefreshingAccountAndDeviceSupplier(Account account, byte deviceId, AccountsManager accountsManager) {
this.account = account;
this.device = account.getDevice(deviceId)
.orElseThrow(() -> new RefreshingAccountAndDeviceNotFoundException("Could not find device"));
this.accountsManager = accountsManager;
}
@Override
public Pair<Account, Device> get() {
if (account.isStale()) {
account = accountsManager.getByAccountIdentifier(account.getUuid())
.orElseThrow(() -> new RuntimeException("Could not find account"));
device = account.getDevice(device.getId())
.orElseThrow(() -> new RefreshingAccountAndDeviceNotFoundException("Could not find device"));
}
return new Pair<>(account, device);
}
}

View File

@@ -5,9 +5,9 @@
package org.whispersystems.textsecuregcm.storage;
public class RefreshingAccountAndDeviceNotFoundException extends RuntimeException {
public class RefreshingAccountNotFoundException extends RuntimeException {
public RefreshingAccountAndDeviceNotFoundException(final String message) {
public RefreshingAccountNotFoundException(final String message) {
super(message);
}

View File

@@ -11,41 +11,40 @@ import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.basic.BasicCredentials;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import javax.annotation.Nullable;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.auth.AuthenticationException;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<AuthenticatedAccount> {
private static final AuthenticationResult<AuthenticatedAccount> CREDENTIALS_NOT_PRESENTED =
new AuthenticationResult<>(Optional.empty(), false);
private static final ReusableAuth<AuthenticatedAccount> CREDENTIALS_NOT_PRESENTED = ReusableAuth.anonymous();
private static final AuthenticationResult<AuthenticatedAccount> INVALID_CREDENTIALS_PRESENTED =
new AuthenticationResult<>(Optional.empty(), true);
private static final ReusableAuth<AuthenticatedAccount> INVALID_CREDENTIALS_PRESENTED = ReusableAuth.invalid();
private final AccountAuthenticator accountAuthenticator;
private final PrincipalSupplier<AuthenticatedAccount> principalSupplier;
public WebSocketAccountAuthenticator(final AccountAuthenticator accountAuthenticator) {
public WebSocketAccountAuthenticator(final AccountAuthenticator accountAuthenticator,
final PrincipalSupplier<AuthenticatedAccount> principalSupplier) {
this.accountAuthenticator = accountAuthenticator;
this.principalSupplier = principalSupplier;
}
@Override
public AuthenticationResult<AuthenticatedAccount> authenticate(final UpgradeRequest request)
public ReusableAuth<AuthenticatedAccount> authenticate(final UpgradeRequest request)
throws AuthenticationException {
try {
final AuthenticationResult<AuthenticatedAccount> authResultFromHeader =
authenticatedAccountFromHeaderAuth(request.getHeader(HttpHeaders.AUTHORIZATION));
// the logic here is that if the `Authorization` header was set for the request,
// it takes the priority and we use the result of the header-based auth
// ignoring the result of the query-based auth.
if (authResultFromHeader.credentialsPresented()) {
return authResultFromHeader;
// If the `Authorization` header was set for the request it takes priority, and we use the result of the
// header-based auth ignoring the result of the query-based auth.
final String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION);
if (authHeader != null) {
return authenticatedAccountFromHeaderAuth(authHeader);
}
return authenticatedAccountFromQueryParams(request);
} catch (final Exception e) {
@@ -55,7 +54,7 @@ public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<Aut
}
}
private AuthenticationResult<AuthenticatedAccount> authenticatedAccountFromQueryParams(final UpgradeRequest request) {
private ReusableAuth<AuthenticatedAccount> authenticatedAccountFromQueryParams(final UpgradeRequest request) {
final Map<String, List<String>> parameters = request.getParameterMap();
final List<String> usernames = parameters.get("login");
final List<String> passwords = parameters.get("password");
@@ -65,16 +64,19 @@ public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<Aut
}
final BasicCredentials credentials = new BasicCredentials(usernames.get(0).replace(" ", "+"),
passwords.get(0).replace(" ", "+"));
return new AuthenticationResult<>(accountAuthenticator.authenticate(credentials), true);
return accountAuthenticator.authenticate(credentials)
.map(authenticatedAccount -> ReusableAuth.authenticated(authenticatedAccount, this.principalSupplier))
.orElse(INVALID_CREDENTIALS_PRESENTED);
}
private AuthenticationResult<AuthenticatedAccount> authenticatedAccountFromHeaderAuth(@Nullable final String authHeader)
private ReusableAuth<AuthenticatedAccount> authenticatedAccountFromHeaderAuth(@Nullable final String authHeader)
throws AuthenticationException {
if (authHeader == null) {
return CREDENTIALS_NOT_PRESENTED;
}
return basicCredentialsFromAuthHeader(authHeader)
.map(credentials -> new AuthenticationResult<>(accountAuthenticator.authenticate(credentials), true))
.flatMap(credentials -> accountAuthenticator.authenticate(credentials))
.map(authenticatedAccount -> ReusableAuth.authenticated(authenticatedAccount, this.principalSupplier))
.orElse(INVALID_CREDENTIALS_PRESENTED);
}
}

View File

@@ -0,0 +1,279 @@
package org.whispersystems.textsecuregcm;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME;
import io.dropwizard.auth.Auth;
import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import java.io.IOException;
import java.net.URI;
import java.util.EnumSet;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.IntStream;
import javax.servlet.DispatcherType;
import javax.servlet.ServletRegistration;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.glassfish.jersey.server.ManagedAsync;
import org.glassfish.jersey.server.ServerProperties;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.storage.RefreshingAccountNotFoundException;
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.ReadOnly;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ExtendWith(DropwizardExtensionsSupport.class)
public class WebsocketReuseAuthIntegrationTest {
private static final AuthenticatedAccount ACCOUNT = mock(AuthenticatedAccount.class);
@SuppressWarnings("unchecked")
private static final PrincipalSupplier<AuthenticatedAccount> PRINCIPAL_SUPPLIER = mock(PrincipalSupplier.class);
private static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION =
new DropwizardAppExtension<>(TestApplication.class);
private WebSocketClient client;
@BeforeEach
void setUp() throws Exception {
reset(PRINCIPAL_SUPPLIER);
reset(ACCOUNT);
when(ACCOUNT.getName()).thenReturn("original");
client = new WebSocketClient();
client.start();
}
@AfterEach
void tearDown() throws Exception {
client.stop();
}
public static class TestApplication extends Application<Configuration> {
@Override
public void run(final Configuration configuration, final Environment environment) throws Exception {
final TestController testController = new TestController();
final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration();
final WebSocketEnvironment<AuthenticatedAccount> webSocketEnvironment =
new WebSocketEnvironment<>(environment, webSocketConfiguration);
environment.jersey().register(testController);
environment.servlets()
.addFilter("RemoteAddressFilter", new RemoteAddressFilter(true))
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(testController);
webSocketEnvironment.jersey().register(new RemoteAddressFilter(true));
webSocketEnvironment.setAuthenticator(upgradeRequest -> ReusableAuth.authenticated(ACCOUNT, PRINCIPAL_SUPPLIER));
webSocketEnvironment.jersey().property(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE);
webSocketEnvironment.setConnectListener(webSocketSessionContext -> {
});
final WebSocketResourceProviderFactory<AuthenticatedAccount> webSocketServlet =
new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedAccount.class,
webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
final ServletRegistration.Dynamic websocketServlet =
environment.servlets().addServlet("WebSocket", webSocketServlet);
websocketServlet.addMapping("/websocket");
websocketServlet.setAsyncSupported(true);
}
}
private WebSocketResponseMessage make1WebsocketRequest(final String requestPath) throws IOException {
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
return testWebsocketListener.doGet(requestPath).join();
}
@ParameterizedTest
@ValueSource(strings = {"/test/read-auth", "/test/optional-read-auth"})
public void readAuth(final String path) throws IOException {
final WebSocketResponseMessage response = make1WebsocketRequest(path);
assertThat(response.getStatus()).isEqualTo(200);
verifyNoMoreInteractions(PRINCIPAL_SUPPLIER);
}
@ParameterizedTest
@ValueSource(strings = {"/test/write-auth", "/test/optional-write-auth"})
public void writeAuth(final String path) throws IOException {
final AuthenticatedAccount copiedAccount = mock(AuthenticatedAccount.class);
when(copiedAccount.getName()).thenReturn("copy");
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(copiedAccount);
final WebSocketResponseMessage response = make1WebsocketRequest(path);
assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.getBody().map(String::new)).get().isEqualTo("copy");
verify(PRINCIPAL_SUPPLIER, times(1)).deepCopy(any());
verifyNoMoreInteractions(PRINCIPAL_SUPPLIER);
}
@Test
public void readAfterWrite() throws IOException {
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(ACCOUNT);
final AuthenticatedAccount account2 = mock(AuthenticatedAccount.class);
when(account2.getName()).thenReturn("refresh");
when(PRINCIPAL_SUPPLIER.refresh(any())).thenReturn(account2);
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
final WebSocketResponseMessage readResponse = testWebsocketListener.doGet("/test/read-auth").join();
assertThat(readResponse.getBody().map(String::new)).get().isEqualTo("original");
final WebSocketResponseMessage writeResponse = testWebsocketListener.doGet("/test/write-auth").join();
assertThat(writeResponse.getBody().map(String::new)).get().isEqualTo("original");
final WebSocketResponseMessage readResponse2 = testWebsocketListener.doGet("/test/read-auth").join();
assertThat(readResponse2.getBody().map(String::new)).get().isEqualTo("refresh");
}
@Test
public void readAfterWriteRefreshFails() throws IOException {
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(ACCOUNT);
when(PRINCIPAL_SUPPLIER.refresh(any())).thenThrow(RefreshingAccountNotFoundException.class);
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
final WebSocketResponseMessage writeResponse = testWebsocketListener.doGet("/test/write-auth").join();
assertThat(writeResponse.getBody().map(String::new)).get().isEqualTo("original");
final WebSocketResponseMessage readResponse2 = testWebsocketListener.doGet("/test/read-auth").join();
assertThat(readResponse2.getStatus()).isEqualTo(500);
}
@Test
public void readConcurrentWithWrite() throws IOException, ExecutionException, InterruptedException, TimeoutException {
final AuthenticatedAccount deepCopy = mock(AuthenticatedAccount.class);
when(deepCopy.getName()).thenReturn("deepCopy");
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(deepCopy);
final AuthenticatedAccount refresh = mock(AuthenticatedAccount.class);
when(refresh.getName()).thenReturn("refresh");
when(PRINCIPAL_SUPPLIER.refresh(any())).thenReturn(refresh);
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
// start a write request that takes a while to finish
final CompletableFuture<WebSocketResponseMessage> writeResponse =
testWebsocketListener.doGet("/test/start-delayed-write/foo");
// send a bunch of reads, they should reflect the original auth
final List<CompletableFuture<WebSocketResponseMessage>> futures = IntStream.range(0, 10)
.boxed().map(i -> testWebsocketListener.doGet("/test/read-auth"))
.toList();
CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)).join();
for (CompletableFuture<WebSocketResponseMessage> future : futures) {
assertThat(future.join().getBody().map(String::new)).get().isEqualTo("original");
}
assertThat(writeResponse.isDone()).isFalse();
// finish the delayed write request
testWebsocketListener.doGet("/test/finish-delayed-write/foo").get(1, TimeUnit.SECONDS);
assertThat(writeResponse.join().getBody().map(String::new)).get().isEqualTo("deepCopy");
// subsequent reads should have the refreshed auth
final WebSocketResponseMessage readResponse = testWebsocketListener.doGet("/test/read-auth").join();
assertThat(readResponse.getBody().map(String::new)).get().isEqualTo("refresh");
}
@Path("/test")
public static class TestController {
private final ConcurrentHashMap<String, CountDownLatch> delayedWriteLatches = new ConcurrentHashMap<>();
@GET
@Path("/read-auth")
@ManagedAsync
public String readAuth(@ReadOnly @Auth final AuthenticatedAccount account) {
return account.getName();
}
@GET
@Path("/optional-read-auth")
@ManagedAsync
public String optionalReadAuth(@ReadOnly @Auth final Optional<AuthenticatedAccount> account) {
return account.map(AuthenticatedAccount::getName).orElse("empty");
}
@GET
@Path("/write-auth")
@ManagedAsync
public String writeAuth(@Auth final AuthenticatedAccount account) {
return account.getName();
}
@GET
@Path("/optional-write-auth")
@ManagedAsync
public String optionalWriteAuth(@Auth final Optional<AuthenticatedAccount> account) {
return account.map(AuthenticatedAccount::getName).orElse("empty");
}
@GET
@Path("/start-delayed-write/{id}")
@ManagedAsync
public String startDelayedWrite(@Auth final AuthenticatedAccount account, @PathParam("id") String id)
throws InterruptedException {
delayedWriteLatches.computeIfAbsent(id, i -> new CountDownLatch(1)).await();
return account.getName();
}
@GET
@Path("/finish-delayed-write/{id}")
@ManagedAsync
public String finishDelayedWrite(@PathParam("id") String id) {
delayedWriteLatches.computeIfAbsent(id, i -> new CountDownLatch(1)).countDown();
return "ok";
}
}
}

View File

@@ -7,8 +7,6 @@ package org.whispersystems.textsecuregcm.auth;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
@@ -30,7 +28,6 @@ import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.Principal;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.LinkedList;
@@ -76,7 +73,9 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketResourceProvider;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
import org.whispersystems.websocket.logging.WebsocketRequestLog;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
@@ -132,38 +131,6 @@ class AuthEnablementRefreshRequirementProviderTest {
.forEach(device -> when(clientPresenceManager.isPresent(uuid, device.getId())).thenReturn(true));
}
@Test
void testBuildDevicesEnabled() {
final byte disabledDeviceId = 3;
final Account account = mock(Account.class);
final List<Device> devices = new ArrayList<>();
when(account.getDevices()).thenReturn(devices);
IntStream.range(1, 5)
.forEach(id -> {
final Device device = mock(Device.class);
when(device.getId()).thenReturn((byte) id);
when(device.isEnabled()).thenReturn(id != disabledDeviceId);
devices.add(device);
});
final Map<Byte, Boolean> devicesEnabled = AuthEnablementRefreshRequirementProvider.buildDevicesEnabledMap(account);
assertEquals(4, devicesEnabled.size());
assertAll(devicesEnabled.entrySet().stream()
.map(deviceAndEnabled -> () -> {
if (deviceAndEnabled.getKey().equals(disabledDeviceId)) {
assertFalse(deviceAndEnabled.getValue());
} else {
assertTrue(deviceAndEnabled.getValue());
}
}));
}
@ParameterizedTest
@MethodSource
void testDeviceEnabledChanged(final Map<Byte, Boolean> initialEnabled, final Map<Byte, Boolean> finalEnabled) {
@@ -308,7 +275,7 @@ class AuthEnablementRefreshRequirementProviderTest {
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
provider = new WebSocketResourceProvider<>("127.0.0.1", RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME,
applicationHandler, requestLog, new TestPrincipal("test", account, authenticatedDevice),
applicationHandler, requestLog, TestPrincipal.reusableAuth("test", account, authenticatedDevice),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
remoteEndpoint = mock(RemoteEndpoint.class);
@@ -349,7 +316,7 @@ class AuthEnablementRefreshRequirementProviderTest {
private final Account account;
private final Device device;
private TestPrincipal(String name, final Account account, final Device device) {
private TestPrincipal(final String name, final Account account, final Device device) {
this.name = name;
this.account = account;
this.device = device;
@@ -369,6 +336,11 @@ class AuthEnablementRefreshRequirementProviderTest {
public Device getAuthenticatedDevice() {
return device;
}
public static ReusableAuth<TestPrincipal> reusableAuth(final String name, final Account account, final Device device) {
return ReusableAuth.authenticated(new TestPrincipal(name, account, device), PrincipalSupplier.forImmutablePrincipal());
}
}
@Path("/v1/test")

View File

@@ -0,0 +1,55 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.IntStream;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class ContainerRequestUtilTest {
@Test
void testBuildDevicesEnabled() {
final byte disabledDeviceId = 3;
final Account account = mock(Account.class);
final List<Device> devices = new ArrayList<>();
when(account.getDevices()).thenReturn(devices);
IntStream.range(1, 5)
.forEach(id -> {
final Device device = mock(Device.class);
when(device.getId()).thenReturn((byte) id);
when(device.isEnabled()).thenReturn(id != disabledDeviceId);
devices.add(device);
});
final Map<Byte, Boolean> devicesEnabled = ContainerRequestUtil.AccountInfo.fromAccount(account).devicesEnabled();
assertEquals(4, devicesEnabled.size());
assertAll(devicesEnabled.entrySet().stream()
.map(deviceAndEnabled -> () -> {
if (deviceAndEnabled.getKey().equals(disabledDeviceId)) {
assertFalse(deviceAndEnabled.getValue());
} else {
assertTrue(deviceAndEnabled.getValue());
}
}));
}
}

View File

@@ -5,108 +5,292 @@
package org.whispersystems.textsecuregcm.auth;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME;
import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.Auth;
import io.dropwizard.auth.AuthDynamicFeature;
import io.dropwizard.auth.basic.BasicCredentialAuthFilter;
import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import java.io.IOException;
import java.net.URI;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.EnumSet;
import java.util.Optional;
import java.util.UUID;
import javax.annotation.Nullable;
import javax.ws.rs.core.SecurityContext;
import org.glassfish.jersey.server.ContainerRequest;
import org.glassfish.jersey.server.monitoring.RequestEvent;
import javax.servlet.DispatcherType;
import javax.servlet.ServletRegistration;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.client.Invocation;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.glassfish.jersey.server.ManagedAsync;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.ReadOnly;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ExtendWith(DropwizardExtensionsSupport.class)
class PhoneNumberChangeRefreshRequirementProviderTest {
private PhoneNumberChangeRefreshRequirementProvider provider;
private Account account;
private RequestEvent requestEvent;
private ContainerRequest request;
private static final UUID ACCOUNT_UUID = UUID.randomUUID();
private static final String NUMBER = "+18005551234";
private static final String CHANGED_NUMBER = "+18005554321";
private static final String TEST_CRED_HEADER = HeaderUtils.basicAuthHeader("test", "password");
private static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION = new DropwizardAppExtension<>(
TestApplication.class);
private static final AccountAuthenticator AUTHENTICATOR = mock(AccountAuthenticator.class);
private static final AccountsManager ACCOUNTS_MANAGER = mock(AccountsManager.class);
private static final ClientPresenceManager CLIENT_PRESENCE = mock(ClientPresenceManager.class);
private WebSocketClient client;
private final Account account1 = new Account();
private final Account account2 = new Account();
private final Device authenticatedDevice = DevicesHelper.createDevice(Device.PRIMARY_ID);
@BeforeEach
void setUp() {
provider = new PhoneNumberChangeRefreshRequirementProvider();
void setUp() throws Exception {
reset(AUTHENTICATOR, CLIENT_PRESENCE, ACCOUNTS_MANAGER);
client = new WebSocketClient();
client.start();
account = mock(Account.class);
final Device device = mock(Device.class);
final UUID uuid = UUID.randomUUID();
account1.setUuid(uuid);
account1.addDevice(authenticatedDevice);
account1.setNumber(NUMBER, UUID.randomUUID());
when(account.getUuid()).thenReturn(ACCOUNT_UUID);
when(account.getNumber()).thenReturn(NUMBER);
when(account.getDevices()).thenReturn(List.of(device));
when(device.getId()).thenReturn(Device.PRIMARY_ID);
account2.setUuid(uuid);
account2.addDevice(authenticatedDevice);
account2.setNumber(CHANGED_NUMBER, UUID.randomUUID());
request = mock(ContainerRequest.class);
}
final Map<String, Object> requestProperties = new HashMap<>();
@AfterEach
void tearDown() throws Exception {
client.stop();
}
doAnswer(invocation -> {
requestProperties.put(invocation.getArgument(0, String.class), invocation.getArgument(1));
return null;
}).when(request).setProperty(anyString(), any());
when(request.getProperty(anyString())).thenAnswer(
invocation -> requestProperties.get(invocation.getArgument(0, String.class)));
public static class TestApplication extends Application<Configuration> {
requestEvent = mock(RequestEvent.class);
when(requestEvent.getContainerRequest()).thenReturn(request);
@Override
public void run(final Configuration configuration, final Environment environment) throws Exception {
final TestController testController = new TestController();
final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration();
final WebSocketEnvironment<AuthenticatedAccount> webSocketEnvironment =
new WebSocketEnvironment<>(environment, webSocketConfiguration);
environment.jersey().register(testController);
webSocketEnvironment.jersey().register(testController);
environment.servlets()
.addFilter("RemoteAddressFilter", new RemoteAddressFilter(true))
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(new RemoteAddressFilter(true));
webSocketEnvironment.jersey()
.register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE));
environment.jersey()
.register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE));
webSocketEnvironment.setConnectListener(webSocketSessionContext -> {
});
environment.jersey().register(new AuthDynamicFeature(new BasicCredentialAuthFilter.Builder<AuthenticatedAccount>()
.setAuthenticator(AUTHENTICATOR)
.buildAuthFilter()));
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(AUTHENTICATOR, mock(PrincipalSupplier.class)));
final WebSocketResourceProviderFactory<AuthenticatedAccount> webSocketServlet =
new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedAccount.class,
webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
final ServletRegistration.Dynamic websocketServlet =
environment.servlets().addServlet("WebSocket", webSocketServlet);
websocketServlet.addMapping("/websocket");
websocketServlet.setAsyncSupported(true);
}
}
enum Protocol { HTTP, WEBSOCKET }
private void makeAnonymousRequest(final Protocol protocol, final String requestPath) throws IOException {
makeRequest(protocol, requestPath, true);
}
/*
* Make an authenticated request that will return account1 as the principal
*/
private void makeAuthenticatedRequest(
final Protocol protocol,
final String requestPath) throws IOException {
when(AUTHENTICATOR.authenticate(any())).thenReturn(Optional.of(new AuthenticatedAccount(account1, authenticatedDevice)));
makeRequest(protocol,requestPath, false);
}
private void makeRequest(final Protocol protocol, final String requestPath, final boolean anonymous) throws IOException {
switch (protocol) {
case WEBSOCKET -> {
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
if (!anonymous) {
upgradeRequest.setHeader(HttpHeaders.AUTHORIZATION, TEST_CRED_HEADER);
}
client.connect(
testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())),
upgradeRequest);
testWebsocketListener.sendRequest(requestPath, "GET", Collections.emptyList(), Optional.empty()).join();
}
case HTTP -> {
final Invocation.Builder request = DROPWIZARD_APP_EXTENSION.client()
.target("http://127.0.0.1:%s%s".formatted(DROPWIZARD_APP_EXTENSION.getLocalPort(), requestPath))
.request();
if (!anonymous) {
request.header(HttpHeaders.AUTHORIZATION, TEST_CRED_HEADER);
}
request.get();
}
}
}
@ParameterizedTest
@EnumSource(Protocol.class)
void handleRequestNoChange(final Protocol protocol) throws IOException {
when(ACCOUNTS_MANAGER.getByAccountIdentifier(any())).thenReturn(Optional.of(account1));
makeAuthenticatedRequest(protocol, "/test/annotated");
// Event listeners can fire after responses are sent
verify(ACCOUNTS_MANAGER, timeout(5000).times(1)).getByAccountIdentifier(eq(account1.getUuid()));
verifyNoMoreInteractions(CLIENT_PRESENCE);
verifyNoMoreInteractions(ACCOUNTS_MANAGER);
}
@ParameterizedTest
@EnumSource(Protocol.class)
void handleRequestChange(final Protocol protocol) throws IOException {
when(ACCOUNTS_MANAGER.getByAccountIdentifier(any())).thenReturn(Optional.of(account2));
when(AUTHENTICATOR.authenticate(any())).thenReturn(Optional.of(new AuthenticatedAccount(account1, authenticatedDevice)));
makeAuthenticatedRequest(protocol, "/test/annotated");
// Make sure we disconnect the account if the account has changed numbers. Event listeners can fire after responses
// are sent, so use a timeout.
verify(CLIENT_PRESENCE, timeout(5000))
.disconnectPresence(eq(account1.getUuid()), eq(authenticatedDevice.getId()));
verifyNoMoreInteractions(CLIENT_PRESENCE);
}
@Test
void handleRequestNoChange() {
setAuthenticatedAccount(request, account);
void handleRequestChangeAsyncEndpoint() throws IOException {
when(ACCOUNTS_MANAGER.getByAccountIdentifier(any())).thenReturn(Optional.of(account2));
when(AUTHENTICATOR.authenticate(any())).thenReturn(Optional.of(new AuthenticatedAccount(account1, authenticatedDevice)));
provider.handleRequestFiltered(requestEvent);
assertEquals(Collections.emptyList(), provider.handleRequestFinished(requestEvent));
// Event listeners with asynchronous HTTP endpoints don't currently correctly maintain state between request and
// response
makeAuthenticatedRequest(Protocol.WEBSOCKET, "/test/async-annotated");
// Make sure we disconnect the account if the account has changed numbers. Event listeners can fire after responses
// are sent, so use a timeout.
verify(CLIENT_PRESENCE, timeout(5000))
.disconnectPresence(eq(account1.getUuid()), eq(authenticatedDevice.getId()));
verifyNoMoreInteractions(CLIENT_PRESENCE);
}
@Test
void handleRequestNumberChange() {
setAuthenticatedAccount(request, account);
@ParameterizedTest
@EnumSource(Protocol.class)
void handleRequestNotAnnotated(final Protocol protocol) throws IOException, InterruptedException {
makeAuthenticatedRequest(protocol,"/test/not-annotated");
provider.handleRequestFiltered(requestEvent);
when(account.getNumber()).thenReturn(CHANGED_NUMBER);
assertEquals(List.of(new Pair<>(ACCOUNT_UUID, Device.PRIMARY_ID)), provider.handleRequestFinished(requestEvent));
// Give a tick for event listeners to run. Racy, but should occasionally catch an errant running listener if one is
// introduced.
Thread.sleep(100);
// Shouldn't even read the account if the method has not been annotated
verifyNoMoreInteractions(ACCOUNTS_MANAGER);
verifyNoMoreInteractions(CLIENT_PRESENCE);
}
@Test
void handleRequestNoAuthenticatedAccount() {
final ContainerRequest request = mock(ContainerRequest.class);
setAuthenticatedAccount(request, null);
@ParameterizedTest
@EnumSource(Protocol.class)
void handleRequestNotAuthenticated(final Protocol protocol) throws IOException, InterruptedException {
makeAnonymousRequest(protocol, "/test/not-authenticated");
when(requestEvent.getContainerRequest()).thenReturn(request);
// Give a tick for event listeners to run. Racy, but should occasionally catch an errant running listener if one is
// introduced.
Thread.sleep(100);
provider.handleRequestFiltered(requestEvent);
assertEquals(Collections.emptyList(), provider.handleRequestFinished(requestEvent));
// Shouldn't even read the account if the method has not been annotated
verifyNoMoreInteractions(ACCOUNTS_MANAGER);
verifyNoMoreInteractions(CLIENT_PRESENCE);
}
private static void setAuthenticatedAccount(final ContainerRequest mockRequest, @Nullable final Account account) {
final SecurityContext securityContext = mock(SecurityContext.class);
when(mockRequest.getSecurityContext()).thenReturn(securityContext);
@Path("/test")
public static class TestController {
if (account != null) {
final AuthenticatedAccount authenticatedAccount = mock(AuthenticatedAccount.class);
@GET
@Path("/annotated")
@ChangesPhoneNumber
public String annotated(@ReadOnly @Auth final AuthenticatedAccount account) {
return "ok";
}
when(securityContext.getUserPrincipal()).thenReturn(authenticatedAccount);
when(authenticatedAccount.getAccount()).thenReturn(account);
} else {
when(securityContext.getUserPrincipal()).thenReturn(null);
@GET
@Path("/async-annotated")
@ChangesPhoneNumber
@ManagedAsync
public String asyncAnnotated(@ReadOnly @Auth final AuthenticatedAccount account) {
return "ok";
}
@GET
@Path("/not-authenticated")
@ChangesPhoneNumber
public String notAuthenticated() {
return "ok";
}
@GET
@Path("/not-annotated")
public String notAnnotated(@ReadOnly @Auth final AuthenticatedAccount account) {
return "ok";
}
}
}

View File

@@ -87,7 +87,7 @@ class BasicCredentialAuthenticationInterceptorTest {
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(accountAuthenticator.authenticate(any()))
.thenReturn(Optional.of(new AuthenticatedAccount(() -> new Pair<>(account, device))));
.thenReturn(Optional.of(new AuthenticatedAccount(account, device)));
} else {
when(accountAuthenticator.authenticate(any()))
.thenReturn(Optional.empty());

View File

@@ -39,7 +39,7 @@ class DirectoryControllerV2Test {
when(account.getUuid()).thenReturn(uuid);
final ExternalServiceCredentials credentials = (ExternalServiceCredentials) controller.getAuthToken(
new AuthenticatedAccount(() -> new Pair<>(account, mock(Device.class)))).getEntity();
new AuthenticatedAccount(account, mock(Device.class))).getEntity();
assertEquals(credentials.username(), "d369bc712e2e0dd36258");
assertEquals(credentials.password(), "1633738643:4433b0fab41f25f79dd4");

View File

@@ -51,7 +51,10 @@ import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.tests.util.TestPrincipal;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketResourceProvider;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
import org.whispersystems.websocket.logging.WebsocketRequestLog;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
@@ -139,7 +142,7 @@ class MetricsRequestEventListenerTest {
final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
final WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, TestPrincipal.reusableAuth("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
final Session session = mock(Session.class);
@@ -201,7 +204,7 @@ class MetricsRequestEventListenerTest {
final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
final WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, TestPrincipal.reusableAuth("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
final Session session = mock(Session.class);
@@ -252,19 +255,6 @@ class MetricsRequestEventListenerTest {
return SubProtocol.WebSocketMessage.parseFrom(responseCaptor.getValue().array()).getResponse();
}
public static class TestPrincipal implements Principal {
private final String name;
private TestPrincipal(String name) {
this.name = name;
}
@Override
public String getName() {
return name;
}
}
@Path("/v1/test")
public static class TestResource {

View File

@@ -1,71 +0,0 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.util.Optional;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.util.Pair;
class RefreshingAccountAndDeviceSupplierTest {
@Test
void test() {
final AccountsManager accountsManager = mock(AccountsManager.class);
final UUID uuid = UUID.randomUUID();
final byte deviceId = 2;
final Account initialAccount = mock(Account.class);
final Device initialDevice = mock(Device.class);
when(initialAccount.getUuid()).thenReturn(uuid);
when(initialDevice.getId()).thenReturn(deviceId);
when(initialAccount.getDevice(deviceId)).thenReturn(Optional.of(initialDevice));
when(accountsManager.getByAccountIdentifier(any(UUID.class))).thenAnswer(answer -> {
final Account account = mock(Account.class);
final Device device = mock(Device.class);
when(account.getUuid()).thenReturn(answer.getArgument(0, UUID.class));
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(device.getId()).thenReturn(deviceId);
return Optional.of(account);
});
final RefreshingAccountAndDeviceSupplier refreshingAccountAndDeviceSupplier = new RefreshingAccountAndDeviceSupplier(
initialAccount, deviceId, accountsManager);
Pair<Account, Device> accountAndDevice = refreshingAccountAndDeviceSupplier.get();
assertSame(initialAccount, accountAndDevice.first());
assertSame(initialDevice, accountAndDevice.second());
accountAndDevice = refreshingAccountAndDeviceSupplier.get();
assertSame(initialAccount, accountAndDevice.first());
assertSame(initialDevice, accountAndDevice.second());
when(initialAccount.isStale()).thenReturn(true);
accountAndDevice = refreshingAccountAndDeviceSupplier.get();
assertNotSame(initialAccount, accountAndDevice.first());
assertNotSame(initialDevice, accountAndDevice.second());
assertEquals(uuid, accountAndDevice.first().getUuid());
}
}

View File

@@ -0,0 +1,27 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.tests.util;
import java.security.Principal;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.auth.PrincipalSupplier;
public class TestPrincipal implements Principal {
private final String name;
private TestPrincipal(String name) {
this.name = name;
}
@Override
public String getName() {
return name;
}
public static ReusableAuth<TestPrincipal> reusableAuth(final String name) {
return ReusableAuth.authenticated(new TestPrincipal(name), PrincipalSupplier.forImmutablePrincipal());
}
}

View File

@@ -0,0 +1,79 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.tests.util;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.whispersystems.websocket.messages.WebSocketMessage;
import org.whispersystems.websocket.messages.WebSocketMessageFactory;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
public class TestWebsocketListener implements WebSocketListener {
private final AtomicLong requestId = new AtomicLong();
private final CompletableFuture<Session> started = new CompletableFuture<>();
private final ConcurrentHashMap<Long, CompletableFuture<WebSocketResponseMessage>> responseFutures = new ConcurrentHashMap<>();
private final WebSocketMessageFactory messageFactory;
public TestWebsocketListener() {
this.messageFactory = new ProtobufWebSocketMessageFactory();
}
@Override
public void onWebSocketConnect(final Session session) {
started.complete(session);
}
public CompletableFuture<WebSocketResponseMessage> doGet(final String requestPath) {
return sendRequest(requestPath, "GET", List.of("Accept: application/json"), Optional.empty());
}
public CompletableFuture<WebSocketResponseMessage> sendRequest(
final String requestPath,
final String verb,
final List<String> headers,
final Optional<byte[]> body) {
return started.thenCompose(session -> {
final long id = requestId.incrementAndGet();
final CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>();
responseFutures.put(id, future);
final byte[] requestBytes = messageFactory.createRequest(
Optional.of(id), verb, requestPath, headers, body).toByteArray();
try {
session.getRemote().sendBytes(ByteBuffer.wrap(requestBytes));
} catch (IOException e) {
throw new RuntimeException(e);
}
return future;
});
}
@Override
public void onWebSocketBinary(final byte[] payload, final int offset, final int length) {
try {
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length);
if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.RESPONSE_MESSAGE) {
responseFutures.get(webSocketMessage.getResponseMessage().getRequestId())
.complete(webSocketMessage.getResponseMessage());
} else {
throw new RuntimeException("Unexpected message type: " + webSocketMessage.getType());
}
} catch (final Exception e) {
throw new RuntimeException(e);
}
}
}

View File

@@ -57,6 +57,7 @@ import org.junit.jupiter.params.provider.MethodSource;
import org.slf4j.Logger;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper;
import org.whispersystems.textsecuregcm.tests.util.TestPrincipal;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.websocket.WebSocketResourceProvider;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
@@ -175,7 +176,8 @@ class LoggingUnhandledExceptionMapperTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog,
TestPrincipal.reusableAuth("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@@ -238,18 +240,4 @@ class LoggingUnhandledExceptionMapperTest {
throw new RuntimeException();
}
}
public static class TestPrincipal implements Principal {
private final String name;
private TestPrincipal(String name) {
this.name = name;
}
@Override
public String getName() {
return name;
}
}
}

View File

@@ -28,8 +28,8 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.auth.PrincipalSupplier;
class WebSocketAccountAuthenticatorTest {
@@ -52,7 +52,7 @@ class WebSocketAccountAuthenticatorTest {
accountAuthenticator = mock(AccountAuthenticator.class);
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
.thenReturn(Optional.of(new AuthenticatedAccount(() -> new Pair<>(mock(Account.class), mock(Device.class)))));
.thenReturn(Optional.of(new AuthenticatedAccount(mock(Account.class), mock(Device.class))));
when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD))))
.thenReturn(Optional.empty());
@@ -66,7 +66,7 @@ class WebSocketAccountAuthenticatorTest {
@Nullable final String authorizationHeaderValue,
final Map<String, List<String>> upgradeRequestParameters,
final boolean expectAccount,
final boolean expectCredentialsPresented) throws Exception {
final boolean expectInvalid) throws Exception {
when(upgradeRequest.getParameterMap()).thenReturn(upgradeRequestParameters);
if (authorizationHeaderValue != null) {
@@ -74,13 +74,13 @@ class WebSocketAccountAuthenticatorTest {
}
final WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(
accountAuthenticator);
accountAuthenticator,
mock(PrincipalSupplier.class));
final WebSocketAuthenticator.AuthenticationResult<AuthenticatedAccount> result = webSocketAuthenticator.authenticate(
upgradeRequest);
final ReusableAuth<AuthenticatedAccount> result = webSocketAuthenticator.authenticate(upgradeRequest);
assertEquals(expectAccount, result.getUser().isPresent());
assertEquals(expectCredentialsPresented, result.credentialsPresented());
assertEquals(expectAccount, result.ref().isPresent());
assertEquals(expectInvalid, result.invalidCredentialsProvided());
}
private static Stream<Arguments> testAuthenticate() {
@@ -94,17 +94,17 @@ class WebSocketAccountAuthenticatorTest {
HeaderUtils.basicAuthHeader(INVALID_USER, INVALID_PASSWORD);
return Stream.of(
// if `Authorization` header is present, outcome should not depend on the value of query parameters
Arguments.of(headerWithValidAuth, Map.of(), true, true),
Arguments.of(headerWithValidAuth, Map.of(), true, false),
Arguments.of(headerWithInvalidAuth, Map.of(), false, true),
Arguments.of("invalid header value", Map.of(), false, true),
Arguments.of(headerWithValidAuth, paramsMapWithValidAuth, true, true),
Arguments.of(headerWithValidAuth, paramsMapWithValidAuth, true, false),
Arguments.of(headerWithInvalidAuth, paramsMapWithValidAuth, false, true),
Arguments.of("invalid header value", paramsMapWithValidAuth, false, true),
Arguments.of(headerWithValidAuth, paramsMapWithInvalidAuth, true, true),
Arguments.of(headerWithValidAuth, paramsMapWithInvalidAuth, true, false),
Arguments.of(headerWithInvalidAuth, paramsMapWithInvalidAuth, false, true),
Arguments.of("invalid header value", paramsMapWithInvalidAuth, false, true),
// if `Authorization` header is not set, outcome should match the query params based auth
Arguments.of(null, paramsMapWithValidAuth, true, true),
Arguments.of(null, paramsMapWithValidAuth, true, false),
Arguments.of(null, paramsMapWithInvalidAuth, false, true),
Arguments.of(null, Map.of(), false, false)
);

View File

@@ -125,7 +125,7 @@ class WebSocketConnectionIntegrationTest {
final WebSocketConnection webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService),
new AuthenticatedAccount(() -> new Pair<>(account, device)),
new AuthenticatedAccount(account, device),
device,
webSocketClient,
scheduledExecutorService,
@@ -210,7 +210,7 @@ class WebSocketConnectionIntegrationTest {
final WebSocketConnection webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService),
new AuthenticatedAccount(() -> new Pair<>(account, device)),
new AuthenticatedAccount(account, device),
device,
webSocketClient,
scheduledExecutorService,
@@ -276,7 +276,7 @@ class WebSocketConnectionIntegrationTest {
final WebSocketConnection webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService),
new AuthenticatedAccount(() -> new Pair<>(account, device)),
new AuthenticatedAccount(account, device),
device,
webSocketClient,
100, // use a very short timeout, so that this test completes quickly

View File

@@ -64,9 +64,9 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import reactor.core.publisher.Flux;
@@ -101,7 +101,7 @@ class WebSocketConnectionTest {
accountsManager = mock(AccountsManager.class);
account = mock(Account.class);
device = mock(Device.class);
auth = new AuthenticatedAccount(() -> new Pair<>(account, device));
auth = new AuthenticatedAccount(account, device);
upgradeRequest = mock(UpgradeRequest.class);
messagesManager = mock(MessagesManager.class);
receiptSender = mock(ReceiptSender.class);
@@ -118,18 +118,19 @@ class WebSocketConnectionTest {
@Test
void testCredentials() throws Exception {
WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator);
WebSocketAccountAuthenticator webSocketAuthenticator =
new WebSocketAccountAuthenticator(accountAuthenticator, mock(PrincipalSupplier.class));
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, messagesManager,
mock(PushNotificationManager.class), mock(ClientPresenceManager.class),
retrySchedulingExecutor, messageDeliveryScheduler, clientReleaseManager);
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
.thenReturn(Optional.of(new AuthenticatedAccount(() -> new Pair<>(account, device))));
.thenReturn(Optional.of(new AuthenticatedAccount(account, device)));
AuthenticationResult<AuthenticatedAccount> account = webSocketAuthenticator.authenticate(upgradeRequest);
when(sessionContext.getAuthenticated()).thenReturn(account.getUser().orElse(null));
when(sessionContext.getAuthenticated(AuthenticatedAccount.class)).thenReturn(account.getUser().orElse(null));
ReusableAuth<AuthenticatedAccount> account = webSocketAuthenticator.authenticate(upgradeRequest);
when(sessionContext.getAuthenticated()).thenReturn(account.ref().orElse(null));
when(sessionContext.getAuthenticated(AuthenticatedAccount.class)).thenReturn(account.ref().orElse(null));
final WebSocketClient webSocketClient = mock(WebSocketClient.class);
when(webSocketClient.getUserAgent()).thenReturn("Signal-Android/6.22.8");
@@ -144,8 +145,8 @@ class WebSocketConnectionTest {
// unauthenticated
when(upgradeRequest.getParameterMap()).thenReturn(Map.of());
account = webSocketAuthenticator.authenticate(upgradeRequest);
assertFalse(account.getUser().isPresent());
assertFalse(account.credentialsPresented());
assertFalse(account.ref().isPresent());
assertFalse(account.invalidCredentialsProvided());
connectListener.onWebSocketConnect(sessionContext);
verify(sessionContext, times(2)).addWebsocketClosedListener(