Use a consistent provisioning address

This commit is contained in:
Jon Chambers
2024-10-01 13:34:37 -04:00
committed by GitHub
parent b284e95394
commit 26503dffdf
3 changed files with 74 additions and 8 deletions

View File

@@ -36,6 +36,7 @@ import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper
import org.whispersystems.textsecuregcm.push.ProvisioningManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.websocket.ProvisioningConnectListener;
@ExtendWith(DropwizardExtensionsSupport.class)
class ProvisioningControllerTest {
@@ -64,7 +65,7 @@ class ProvisioningControllerTest {
@Test
void sendProvisioningMessage() {
final String provisioningAddress = RandomStringUtils.randomAlphanumeric(16);
final String provisioningAddress = ProvisioningConnectListener.generateProvisioningAddress();
final byte[] messageBody = "test".getBytes(StandardCharsets.UTF_8);
when(provisioningManager.sendProvisioningMessage(any(), any())).thenReturn(true);
@@ -84,7 +85,7 @@ class ProvisioningControllerTest {
@Test
void sendProvisioningMessageRateLimited() throws RateLimitExceededException {
final String provisioningAddress = RandomStringUtils.randomAlphanumeric(16);
final String provisioningAddress = ProvisioningConnectListener.generateProvisioningAddress();
final byte[] messageBody = "test".getBytes(StandardCharsets.UTF_8);
doThrow(new RateLimitExceededException(Duration.ZERO))

View File

@@ -0,0 +1,63 @@
package org.whispersystems.textsecuregcm.websocket;
import com.google.protobuf.InvalidProtocolBufferException;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.push.ProvisioningManager;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import java.util.Optional;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
class ProvisioningConnectListenerTest {
private ProvisioningManager provisioningManager;
private ProvisioningConnectListener provisioningConnectListener;
@BeforeEach
void setUp() {
provisioningManager = mock(ProvisioningManager.class);
provisioningConnectListener = new ProvisioningConnectListener(provisioningManager);
}
@Test
void onWebSocketConnect() {
final WebSocketClient webSocketClient = mock(WebSocketClient.class);
final WebSocketSessionContext context = new WebSocketSessionContext(webSocketClient);
provisioningConnectListener.onWebSocketConnect(context);
context.notifyClosed(1000, "Test");
final ArgumentCaptor<String> addListenerProvisioningAddressCaptor = ArgumentCaptor.forClass(String.class);
final ArgumentCaptor<String> removeListenerProvisioningAddressCaptor = ArgumentCaptor.forClass(String.class);
@SuppressWarnings("unchecked") final ArgumentCaptor<Optional<byte[]>> sendAddressCaptor =
ArgumentCaptor.forClass(Optional.class);
verify(provisioningManager).addListener(addListenerProvisioningAddressCaptor.capture(), any());
verify(provisioningManager).removeListener(removeListenerProvisioningAddressCaptor.capture());
verify(webSocketClient).sendRequest(eq("PUT"), eq("/v1/address"), any(), sendAddressCaptor.capture());
final String sentProvisioningAddress = sendAddressCaptor.getValue()
.map(provisioningAddressBytes -> {
try {
return MessageProtos.ProvisioningAddress.parseFrom(provisioningAddressBytes);
} catch (final InvalidProtocolBufferException e) {
throw new RuntimeException(e);
}
})
.map(MessageProtos.ProvisioningAddress::getAddress)
.orElseThrow();
assertEquals(addListenerProvisioningAddressCaptor.getValue(), removeListenerProvisioningAddressCaptor.getValue());
assertEquals(addListenerProvisioningAddressCaptor.getValue(), sentProvisioningAddress);
}
}