Use refreshing AuthenticatedAccount for @Auth

This commit is contained in:
Chris Eager
2021-08-11 14:52:25 -05:00
committed by GitHub
parent b3e6a50dee
commit 31022aeb79
53 changed files with 1251 additions and 969 deletions

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@@ -40,6 +40,7 @@ import org.junit.Rule;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.metrics.PushLatencyManager;
@@ -52,6 +53,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbRule;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
@@ -90,13 +92,13 @@ public class WebSocketConnectionIntegrationTest extends AbstractRedisClusterTest
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(1L);
webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, mock(PushLatencyManager.class), reportMessageManager),
account,
device,
webSocketClient,
retrySchedulingExecutor);
webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, mock(PushLatencyManager.class), reportMessageManager),
new AuthenticatedAccount(() -> new Pair<>(account, device)),
device,
webSocketClient,
retrySchedulingExecutor);
}
@After

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@@ -24,6 +24,7 @@ import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import io.dropwizard.auth.basic.BasicCredentials;
import io.lettuce.core.RedisException;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
@@ -39,7 +40,6 @@ import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import io.lettuce.core.RedisException;
import org.apache.commons.lang3.RandomStringUtils;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.junit.Before;
@@ -48,6 +48,7 @@ import org.mockito.ArgumentMatchers;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
@@ -58,6 +59,7 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
@@ -75,6 +77,7 @@ public class WebSocketConnectionTest {
private AccountsManager accountsManager;
private Account account;
private Device device;
private AuthenticatedAccount auth;
private UpgradeRequest upgradeRequest;
private ReceiptSender receiptSender;
private ApnFallbackManager apnFallbackManager;
@@ -86,6 +89,7 @@ public class WebSocketConnectionTest {
accountsManager = mock(AccountsManager.class);
account = mock(Account.class);
device = mock(Device.class);
auth = new AuthenticatedAccount(() -> new Pair<>(account, device));
upgradeRequest = mock(UpgradeRequest.class);
receiptSender = mock(ReceiptSender.class);
apnFallbackManager = mock(ApnFallbackManager.class);
@@ -94,35 +98,42 @@ public class WebSocketConnectionTest {
@Test
public void testCredentials() throws Exception {
MessagesManager storedMessages = mock(MessagesManager.class);
MessagesManager storedMessages = mock(MessagesManager.class);
WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator);
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, storedMessages, mock(MessageSender.class), apnFallbackManager, mock(ClientPresenceManager.class),
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, storedMessages,
mock(MessageSender.class), apnFallbackManager, mock(ClientPresenceManager.class),
retrySchedulingExecutor);
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
.thenReturn(Optional.of(account));
.thenReturn(Optional.of(new AuthenticatedAccount(() -> new Pair<>(account, device))));
when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD))))
.thenReturn(Optional.<Account>empty());
.thenReturn(Optional.empty());
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(upgradeRequest.getParameterMap()).thenReturn(new HashMap<String, List<String>>() {{
put("login", new LinkedList<String>() {{add(VALID_USER);}});
put("password", new LinkedList<String>() {{add(VALID_PASSWORD);}});
when(upgradeRequest.getParameterMap()).thenReturn(new HashMap<>() {{
put("login", new LinkedList<>() {{
add(VALID_USER);
}});
put("password", new LinkedList<>() {{
add(VALID_PASSWORD);
}});
}});
AuthenticationResult<Account> account = webSocketAuthenticator.authenticate(upgradeRequest);
when(sessionContext.getAuthenticated(Account.class)).thenReturn(account.getUser().orElse(null));
AuthenticationResult<AuthenticatedAccount> account = webSocketAuthenticator.authenticate(upgradeRequest);
when(sessionContext.getAuthenticated(AuthenticatedAccount.class)).thenReturn(account.getUser().orElse(null));
connectListener.onWebSocketConnect(sessionContext);
verify(sessionContext).addListener(any(WebSocketSessionContext.WebSocketEventListener.class));
when(upgradeRequest.getParameterMap()).thenReturn(new HashMap<String, List<String>>() {{
put("login", new LinkedList<String>() {{add(INVALID_USER);}});
put("password", new LinkedList<String>() {{add(INVALID_PASSWORD);}});
put("login", new LinkedList<String>() {{
add(INVALID_USER);
}});
put("password", new LinkedList<String>() {{
add(INVALID_PASSWORD);
}});
}});
account = webSocketAuthenticator.authenticate(upgradeRequest);
@@ -148,7 +159,6 @@ public class WebSocketConnectionTest {
when(device.getId()).thenReturn(2L);
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
@@ -184,12 +194,13 @@ public class WebSocketConnectionTest {
});
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages,
account, device, client, retrySchedulingExecutor);
auth, device, client, retrySchedulingExecutor);
connection.start();
verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any());
verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class),
ArgumentMatchers.<Optional<byte[]>>any());
assertTrue(futures.size() == 3);
assertEquals(3, futures.size());
WebSocketResponseMessage response = mock(WebSocketResponseMessage.class);
when(response.getStatus()).thenReturn(200);
@@ -199,7 +210,7 @@ public class WebSocketConnectionTest {
futures.get(2).completeExceptionally(new IOException());
verify(storedMessages, times(1)).delete(eq(accountUuid), eq(2L), eq(outgoingMessages.get(1).getGuid()));
verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender1"), eq(2222L));
verify(receiptSender, times(1)).sendReceipt(eq(auth), eq("sender1"), eq(2222L));
connection.stop();
verify(client).close(anyInt(), anyString());
@@ -207,9 +218,10 @@ public class WebSocketConnectionTest {
@Test(timeout = 5_000L)
public void testOnlineSend() throws Exception {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor);
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
retrySchedulingExecutor);
final UUID accountUuid = UUID.randomUUID();
@@ -219,7 +231,7 @@ public class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false))
.thenReturn(new OutgoingMessageEntityList(List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first")), false))
.thenReturn(new OutgoingMessageEntityList(List.of(createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second")), false));
@@ -300,7 +312,6 @@ public class WebSocketConnectionTest {
when(device.getId()).thenReturn(2L);
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(UUID.randomUUID());
@@ -336,11 +347,12 @@ public class WebSocketConnectionTest {
});
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages,
account, device, client, retrySchedulingExecutor);
auth, device, client, retrySchedulingExecutor);
connection.start();
verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any());
verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class),
ArgumentMatchers.<Optional<byte[]>>any());
assertEquals(futures.size(), 2);
@@ -349,7 +361,7 @@ public class WebSocketConnectionTest {
futures.get(1).complete(response);
futures.get(0).completeExceptionally(new IOException());
verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender2"), eq(secondMessage.getTimestamp()));
verify(receiptSender, times(1)).sendReceipt(eq(auth), eq("sender2"), eq(secondMessage.getTimestamp()));
connection.stop();
verify(client).close(anyInt(), anyString());
@@ -357,19 +369,21 @@ public class WebSocketConnectionTest {
@Test(timeout = 5000L)
public void testProcessStoredMessageConcurrency() throws InterruptedException {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor);
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
retrySchedulingExecutor);
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(1L);
when(client.getUserAgent()).thenReturn("Test-UA");
final AtomicBoolean threadWaiting = new AtomicBoolean(false);
final AtomicBoolean threadWaiting = new AtomicBoolean(false);
final AtomicBoolean returnMessageList = new AtomicBoolean(false);
when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false)).thenAnswer((Answer<OutgoingMessageEntityList>)invocation -> {
when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false)).thenAnswer(
(Answer<OutgoingMessageEntityList>) invocation -> {
synchronized (threadWaiting) {
threadWaiting.set(true);
threadWaiting.notifyAll();
@@ -418,9 +432,10 @@ public class WebSocketConnectionTest {
@Test(timeout = 5000L)
public void testProcessStoredMessagesMultiplePages() throws InterruptedException {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor);
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
retrySchedulingExecutor);
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
@@ -428,8 +443,8 @@ public class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA");
final List<OutgoingMessageEntity> firstPageMessages =
List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first"),
createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second"));
List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first"),
createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second"));
final List<OutgoingMessageEntity> secondPageMessages =
List.of(createMessage(3L, false, "sender1", UUID.randomUUID(), 3333, false, "third"));
@@ -463,7 +478,8 @@ public class WebSocketConnectionTest {
public void testProcessStoredMessagesContainsSenderUuid() throws InterruptedException {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
retrySchedulingExecutor);
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
@@ -471,7 +487,8 @@ public class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA");
final UUID senderUuid = UUID.randomUUID();
final List<OutgoingMessageEntity> messages = List.of(createMessage(1L, false, "senderE164", senderUuid, 1111L, false, "message the first"));
final List<OutgoingMessageEntity> messages = List.of(
createMessage(1L, false, "senderE164", senderUuid, 1111L, false, "message the first"));
final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(messages, false);
when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false)).thenReturn(firstPage);
@@ -511,9 +528,10 @@ public class WebSocketConnectionTest {
@Test
public void testProcessStoredMessagesSingleEmptyCall() {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor);
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
retrySchedulingExecutor);
final UUID accountUuid = UUID.randomUUID();
@@ -523,7 +541,7 @@ public class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false));
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
@@ -540,10 +558,11 @@ public class WebSocketConnectionTest {
@Test(timeout = 5000L)
public void testRequeryOnStateMismatch() throws InterruptedException {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor);
final UUID accountUuid = UUID.randomUUID();
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
retrySchedulingExecutor);
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
@@ -551,8 +570,8 @@ public class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA");
final List<OutgoingMessageEntity> firstPageMessages =
List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first"),
createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second"));
List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first"),
createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second"));
final List<OutgoingMessageEntity> secondPageMessages =
List.of(createMessage(3L, false, "sender1", UUID.randomUUID(), 3333, false, "third"));
@@ -587,9 +606,10 @@ public class WebSocketConnectionTest {
@Test
public void testProcessCachedMessagesOnly() {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor);
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
retrySchedulingExecutor);
final UUID accountUuid = UUID.randomUUID();
@@ -599,7 +619,7 @@ public class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false));
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
@@ -619,9 +639,10 @@ public class WebSocketConnectionTest {
@Test
public void testProcessDatabaseMessagesAfterPersist() {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor);
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
retrySchedulingExecutor);
final UUID accountUuid = UUID.randomUUID();
@@ -631,7 +652,7 @@ public class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false));
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
@@ -664,7 +685,6 @@ public class WebSocketConnectionTest {
when(device.getId()).thenReturn(2L);
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
@@ -689,20 +709,24 @@ public class WebSocketConnectionTest {
final WebSocketClient client = mock(WebSocketClient.class);
when(client.getUserAgent()).thenReturn(userAgent);
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any()))
.thenAnswer(new Answer<CompletableFuture<WebSocketResponseMessage>>() {
@Override
public CompletableFuture<WebSocketResponseMessage> answer(InvocationOnMock invocationOnMock) throws Throwable {
CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>();
futures.add(future);
return future;
}
});
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class),
ArgumentMatchers.<Optional<byte[]>>any()))
.thenAnswer(new Answer<CompletableFuture<WebSocketResponseMessage>>() {
@Override
public CompletableFuture<WebSocketResponseMessage> answer(InvocationOnMock invocationOnMock)
throws Throwable {
CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>();
futures.add(future);
return future;
}
});
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, account, device, client, retrySchedulingExecutor);
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client,
retrySchedulingExecutor);
connection.start();
verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any());
verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class),
ArgumentMatchers.<Optional<byte[]>>any());
assertEquals(2, futures.size());
@@ -737,7 +761,6 @@ public class WebSocketConnectionTest {
when(device.getId()).thenReturn(2L);
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
@@ -762,20 +785,24 @@ public class WebSocketConnectionTest {
final WebSocketClient client = mock(WebSocketClient.class);
when(client.getUserAgent()).thenReturn(userAgent);
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any()))
.thenAnswer(new Answer<CompletableFuture<WebSocketResponseMessage>>() {
@Override
public CompletableFuture<WebSocketResponseMessage> answer(InvocationOnMock invocationOnMock) throws Throwable {
CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>();
futures.add(future);
return future;
}
});
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class),
ArgumentMatchers.<Optional<byte[]>>any()))
.thenAnswer(new Answer<CompletableFuture<WebSocketResponseMessage>>() {
@Override
public CompletableFuture<WebSocketResponseMessage> answer(InvocationOnMock invocationOnMock)
throws Throwable {
CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>();
futures.add(future);
return future;
}
});
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, account, device, client, retrySchedulingExecutor);
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client,
retrySchedulingExecutor);
connection.start();
verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any());
verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class),
ArgumentMatchers.<Optional<byte[]>>any());
assertEquals(3, futures.size());
@@ -799,7 +826,6 @@ public class WebSocketConnectionTest {
when(device.getId()).thenReturn(2L);
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
@@ -808,17 +834,20 @@ public class WebSocketConnectionTest {
when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false))
.thenThrow(new RedisException("OH NO"));
when(retrySchedulingExecutor.schedule(any(Runnable.class), anyLong(), any())).thenAnswer((Answer<ScheduledFuture<?>>) invocation -> {
invocation.getArgument(0, Runnable.class).run();
return mock(ScheduledFuture.class);
});
when(retrySchedulingExecutor.schedule(any(Runnable.class), anyLong(), any())).thenAnswer(
(Answer<ScheduledFuture<?>>) invocation -> {
invocation.getArgument(0, Runnable.class).run();
return mock(ScheduledFuture.class);
});
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketClient client = mock(WebSocketClient.class);
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, account, device, client, retrySchedulingExecutor);
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client,
retrySchedulingExecutor);
connection.start();
verify(retrySchedulingExecutor, times(WebSocketConnection.MAX_CONSECUTIVE_RETRIES)).schedule(any(Runnable.class), anyLong(), any());
verify(retrySchedulingExecutor, times(WebSocketConnection.MAX_CONSECUTIVE_RETRIES)).schedule(any(Runnable.class),
anyLong(), any());
verify(client).close(eq(1011), anyString());
}