Enable header-based auth for WebSocket connections

This commit is contained in:
Sergey Skrobotov
2023-09-25 11:28:23 -07:00
parent a263611746
commit d0fdae3df7
8 changed files with 147 additions and 85 deletions

View File

@@ -13,7 +13,6 @@ import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import io.dropwizard.auth.basic.BasicCredentials;
import io.grpc.CallCredentials;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
@@ -30,7 +29,6 @@ import java.util.stream.Stream;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
@@ -41,6 +39,7 @@ import org.whispersystems.textsecuregcm.auth.BaseAccountAuthenticator;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Pair;
class BasicCredentialAuthenticationInterceptorTest {
@@ -122,8 +121,10 @@ class BasicCredentialAuthenticationInterceptorTest {
malformedCredentialHeaders.put(BasicCredentialAuthenticationInterceptor.BASIC_CREDENTIALS, "Incorrect");
final Metadata structurallyValidCredentialHeaders = new Metadata();
structurallyValidCredentialHeaders.put(BasicCredentialAuthenticationInterceptor.BASIC_CREDENTIALS,
UUID.randomUUID() + ":" + RandomStringUtils.randomAlphanumeric(16));
structurallyValidCredentialHeaders.put(
BasicCredentialAuthenticationInterceptor.BASIC_CREDENTIALS,
HeaderUtils.basicAuthHeader(UUID.randomUUID().toString(), RandomStringUtils.randomAlphanumeric(16))
);
return Stream.of(
Arguments.of(new Metadata(), true, false),
@@ -132,22 +133,4 @@ class BasicCredentialAuthenticationInterceptorTest {
Arguments.of(structurallyValidCredentialHeaders, true, true)
);
}
@Test
void extractBasicCredentials() {
final String username = UUID.randomUUID().toString();
final String password = RandomStringUtils.random(16);
final BasicCredentials basicCredentials =
BasicCredentialAuthenticationInterceptor.extractBasicCredentials(username + ":" + password);
assertEquals(username, basicCredentials.getUsername());
assertEquals(password, basicCredentials.getPassword());
}
@Test
void extractBasicCredentialsIllegalArgument() {
assertThrows(IllegalArgumentException.class,
() -> BasicCredentialAuthenticationInterceptor.extractBasicCredentials("This does not include a password"));
}
}

View File

@@ -6,17 +6,18 @@
package org.whispersystems.textsecuregcm.websocket;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.common.net.HttpHeaders;
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import io.dropwizard.auth.basic.BasicCredentials;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
@@ -26,6 +27,7 @@ import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
@@ -33,9 +35,12 @@ class WebSocketAccountAuthenticatorTest {
private static final String VALID_USER = PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("NZ"), PhoneNumberUtil.PhoneNumberFormat.E164);
private static final String VALID_PASSWORD = "valid";
private static final String INVALID_USER = PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("AU"), PhoneNumberUtil.PhoneNumberFormat.E164);
private static final String INVALID_PASSWORD = "invalid";
private AccountAuthenticator accountAuthenticator;
@@ -57,10 +62,16 @@ class WebSocketAccountAuthenticatorTest {
@ParameterizedTest
@MethodSource
void testAuthenticate(final Map<String, List<String>> upgradeRequestParameters, final boolean expectAccount,
final boolean expectRequired) throws Exception {
void testAuthenticate(
@Nullable final String authorizationHeaderValue,
final Map<String, List<String>> upgradeRequestParameters,
final boolean expectAccount,
final boolean expectCredentialsPresented) throws Exception {
when(upgradeRequest.getParameterMap()).thenReturn(upgradeRequestParameters);
if (authorizationHeaderValue != null) {
when(upgradeRequest.getHeader(eq(HttpHeaders.AUTHORIZATION))).thenReturn(authorizationHeaderValue);
}
final WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(
accountAuthenticator);
@@ -68,20 +79,34 @@ class WebSocketAccountAuthenticatorTest {
final WebSocketAuthenticator.AuthenticationResult<AuthenticatedAccount> result = webSocketAuthenticator.authenticate(
upgradeRequest);
if (expectAccount) {
assertTrue(result.getUser().isPresent());
} else {
assertTrue(result.getUser().isEmpty());
}
assertEquals(expectRequired, result.isRequired());
assertEquals(expectAccount, result.getUser().isPresent());
assertEquals(expectCredentialsPresented, result.credentialsPresented());
}
private static Stream<Arguments> testAuthenticate() {
final Map<String, List<String>> paramsMapWithValidAuth =
Map.of("login", List.of(VALID_USER), "password", List.of(VALID_PASSWORD));
final Map<String, List<String>> paramsMapWithInvalidAuth =
Map.of("login", List.of(INVALID_USER), "password", List.of(INVALID_PASSWORD));
final String headerWithValidAuth =
HeaderUtils.basicAuthHeader(VALID_USER, VALID_PASSWORD);
final String headerWithInvalidAuth =
HeaderUtils.basicAuthHeader(INVALID_USER, INVALID_PASSWORD);
return Stream.of(
Arguments.of(Map.of("login", List.of(VALID_USER), "password", List.of(VALID_PASSWORD)), true, true),
Arguments.of(Map.of("login", List.of(INVALID_USER), "password", List.of(INVALID_PASSWORD)), false, true),
Arguments.of(Map.of(), false, false)
// if `Authorization` header is present, outcome should not depend on the value of query parameters
Arguments.of(headerWithValidAuth, Map.of(), true, true),
Arguments.of(headerWithInvalidAuth, Map.of(), false, true),
Arguments.of("invalid header value", Map.of(), false, true),
Arguments.of(headerWithValidAuth, paramsMapWithValidAuth, true, true),
Arguments.of(headerWithInvalidAuth, paramsMapWithValidAuth, false, true),
Arguments.of("invalid header value", paramsMapWithValidAuth, false, true),
Arguments.of(headerWithValidAuth, paramsMapWithInvalidAuth, true, true),
Arguments.of(headerWithInvalidAuth, paramsMapWithInvalidAuth, false, true),
Arguments.of("invalid header value", paramsMapWithInvalidAuth, false, true),
// if `Authorization` header is not set, outcome should match the query params based auth
Arguments.of(null, paramsMapWithValidAuth, true, true),
Arguments.of(null, paramsMapWithInvalidAuth, false, true),
Arguments.of(null, Map.of(), false, false)
);
}
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2013-2022 Signal Messenger, LLC
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@@ -142,7 +142,7 @@ class WebSocketConnectionTest {
when(upgradeRequest.getParameterMap()).thenReturn(Map.of());
account = webSocketAuthenticator.authenticate(upgradeRequest);
assertFalse(account.getUser().isPresent());
assertFalse(account.isRequired());
assertFalse(account.credentialsPresented());
connectListener.onWebSocketConnect(sessionContext);
verify(sessionContext, times(2)).addWebsocketClosedListener(