mirror of
https://github.com/signalapp/Signal-Server
synced 2026-04-21 16:28:05 +01:00
Update noise-gRPC protocol errors
This commit is contained in:
@@ -8,6 +8,7 @@ import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
import com.southernstorm.noise.protocol.CipherStatePair;
|
||||
import com.southernstorm.noise.protocol.Noise;
|
||||
@@ -16,12 +17,13 @@ import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelFuture;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.net.InetAddress;
|
||||
import java.net.UnknownHostException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Arrays;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
@@ -36,16 +38,29 @@ import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.signal.libsignal.protocol.ecc.Curve;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
|
||||
|
||||
abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
|
||||
|
||||
protected ECKeyPair serverKeyPair;
|
||||
protected ClientPublicKeysManager clientPublicKeysManager;
|
||||
|
||||
private NoiseHandshakeCompleteHandler noiseHandshakeCompleteHandler;
|
||||
|
||||
private EmbeddedChannel embeddedChannel;
|
||||
|
||||
static final String USER_AGENT = "Test/User-Agent";
|
||||
static final String ACCEPT_LANGUAGE = "test-lang";
|
||||
static final InetAddress REMOTE_ADDRESS;
|
||||
static {
|
||||
try {
|
||||
REMOTE_ADDRESS = InetAddress.getByAddress(new byte[]{0,1,2,3});
|
||||
} catch (UnknownHostException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private static class PongHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
@Override
|
||||
@@ -93,7 +108,10 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
|
||||
void setUp() {
|
||||
serverKeyPair = Curve.generateKeyPair();
|
||||
noiseHandshakeCompleteHandler = new NoiseHandshakeCompleteHandler();
|
||||
embeddedChannel = new EmbeddedChannel(getHandler(serverKeyPair), noiseHandshakeCompleteHandler);
|
||||
clientPublicKeysManager = mock(ClientPublicKeysManager.class);
|
||||
embeddedChannel = new EmbeddedChannel(
|
||||
new NoiseHandshakeHandler(clientPublicKeysManager, serverKeyPair),
|
||||
noiseHandshakeCompleteHandler);
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
@@ -110,8 +128,6 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
|
||||
return noiseHandshakeCompleteHandler.getHandshakeCompleteEvent();
|
||||
}
|
||||
|
||||
protected abstract ChannelHandler getHandler(final ECKeyPair serverKeyPair);
|
||||
|
||||
protected abstract CipherStatePair doHandshake() throws Throwable;
|
||||
|
||||
/**
|
||||
@@ -140,7 +156,7 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
|
||||
|
||||
final ByteBuf content = Unpooled.wrappedBuffer(contentBytes);
|
||||
|
||||
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(content).await();
|
||||
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(new NoiseHandshakeInit(REMOTE_ADDRESS, HandshakePattern.IK, content)).await();
|
||||
|
||||
assertFalse(writeFuture.isSuccess());
|
||||
assertInstanceOf(NoiseHandshakeException.class, writeFuture.cause());
|
||||
@@ -292,4 +308,19 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
|
||||
embeddedChannel.pipeline().fireChannelRead(Unpooled.wrappedBuffer(big));
|
||||
assertThrows(NoiseException.class, embeddedChannel::checkException);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void channelAttributes() throws Throwable {
|
||||
doHandshake();
|
||||
final NoiseIdentityDeterminedEvent event = getNoiseHandshakeCompleteEvent();
|
||||
assertEquals(REMOTE_ADDRESS, event.remoteAddress());
|
||||
assertEquals(USER_AGENT, event.userAgent());
|
||||
assertEquals(ACCEPT_LANGUAGE, event.acceptLanguage());
|
||||
}
|
||||
|
||||
protected NoiseTunnelProtos.HandshakeInit.Builder baseHandshakeInit() {
|
||||
return NoiseTunnelProtos.HandshakeInit.newBuilder()
|
||||
.setUserAgent(USER_AGENT)
|
||||
.setAcceptLanguage(ACCEPT_LANGUAGE);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -248,7 +248,10 @@ public abstract class AbstractNoiseTunnelServerIntegrationTest extends AbstractL
|
||||
} finally {
|
||||
channel.shutdown();
|
||||
}
|
||||
assertClosedWith(client, CloseFrameEvent.CloseReason.AUTHENTICATION_ERROR);
|
||||
assertEquals(
|
||||
NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY,
|
||||
client.getHandshakeEventFuture().get(1, TimeUnit.SECONDS).handshakeResponse().getCode());
|
||||
assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -269,12 +272,35 @@ public abstract class AbstractNoiseTunnelServerIntegrationTest extends AbstractL
|
||||
} finally {
|
||||
channel.shutdown();
|
||||
}
|
||||
|
||||
assertClosedWith(client, CloseFrameEvent.CloseReason.AUTHENTICATION_ERROR);
|
||||
assertEquals(
|
||||
NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY,
|
||||
client.getHandshakeEventFuture().get(1, TimeUnit.SECONDS).handshakeResponse().getCode());
|
||||
assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
void clientNormalClosure() throws InterruptedException {
|
||||
final NoiseTunnelClient client = anonymous().build();
|
||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||
try {
|
||||
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
|
||||
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
|
||||
|
||||
assertTrue(response.getAccountIdentifier().isEmpty());
|
||||
assertEquals(0, response.getDeviceId());
|
||||
client.close();
|
||||
|
||||
// When we gracefully close the tunnel client, we should send an OK close frame
|
||||
final CloseFrameEvent closeFrame = client.closeFrameFuture().join();
|
||||
assertEquals(CloseFrameEvent.CloseInitiator.CLIENT, closeFrame.closeInitiator());
|
||||
assertEquals(CloseFrameEvent.CloseReason.OK, closeFrame.closeReason());
|
||||
} finally {
|
||||
channel.shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void connectAnonymous() throws InterruptedException {
|
||||
try (final NoiseTunnelClient client = anonymous().build()) {
|
||||
|
||||
@@ -8,28 +8,23 @@ import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.southernstorm.noise.protocol.CipherStatePair;
|
||||
import com.southernstorm.noise.protocol.HandshakeState;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import java.util.Optional;
|
||||
import javax.crypto.BadPaddingException;
|
||||
import javax.crypto.ShortBufferException;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
|
||||
class NoiseAnonymousHandlerTest extends AbstractNoiseHandlerTest {
|
||||
|
||||
@Override
|
||||
protected NoiseAnonymousHandler getHandler(final ECKeyPair serverKeyPair) {
|
||||
|
||||
return new NoiseAnonymousHandler(serverKeyPair);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected CipherStatePair doHandshake() throws Exception {
|
||||
return doHandshake(new byte[0]);
|
||||
return doHandshake(baseHandshakeInit().build().toByteArray());
|
||||
}
|
||||
|
||||
private CipherStatePair doHandshake(final byte[] requestPayload) throws Exception {
|
||||
@@ -49,26 +44,35 @@ class NoiseAnonymousHandlerTest extends AbstractNoiseHandlerTest {
|
||||
assertEquals(
|
||||
initiateHandshakeMessageLength,
|
||||
clientHandshakeState.writeMessage(initiateHandshakeMessage, 0, requestPayload, 0, requestPayload.length));
|
||||
final ByteBuf initiateHandshakeMessageBuf = Unpooled.wrappedBuffer(initiateHandshakeMessage);
|
||||
assertTrue(embeddedChannel.writeOneInbound(initiateHandshakeMessageBuf).await().isSuccess());
|
||||
assertEquals(0, initiateHandshakeMessageBuf.refCnt());
|
||||
final NoiseHandshakeInit message = new NoiseHandshakeInit(
|
||||
REMOTE_ADDRESS,
|
||||
HandshakePattern.NK,
|
||||
Unpooled.wrappedBuffer(initiateHandshakeMessage));
|
||||
assertTrue(embeddedChannel.writeOneInbound(message).await().isSuccess());
|
||||
assertEquals(0, message.refCnt());
|
||||
|
||||
embeddedChannel.runPendingTasks();
|
||||
|
||||
// Read responder handshake message
|
||||
assertFalse(embeddedChannel.outboundMessages().isEmpty());
|
||||
final ByteBuf responderHandshakeFrame = (ByteBuf) embeddedChannel.outboundMessages().poll();
|
||||
@SuppressWarnings("DataFlowIssue") final byte[] responderHandshakeBytes =
|
||||
new byte[responderHandshakeFrame.readableBytes()];
|
||||
responderHandshakeFrame.readBytes(responderHandshakeBytes);
|
||||
assertNotNull(responderHandshakeFrame);
|
||||
final byte[] responderHandshakeBytes = ByteBufUtil.getBytes(responderHandshakeFrame);
|
||||
|
||||
// ephemeral key, empty encrypted payload AEAD tag
|
||||
final byte[] handshakeResponsePayload = new byte[32 + 16];
|
||||
final NoiseTunnelProtos.HandshakeResponse expectedHandshakeResponse = NoiseTunnelProtos.HandshakeResponse.newBuilder()
|
||||
.setCode(NoiseTunnelProtos.HandshakeResponse.Code.OK)
|
||||
.build();
|
||||
|
||||
assertEquals(0,
|
||||
// ephemeral key, payload, AEAD tag
|
||||
assertEquals(32 + expectedHandshakeResponse.getSerializedSize() + 16, responderHandshakeBytes.length);
|
||||
|
||||
final byte[] handshakeResponsePlaintext = new byte[expectedHandshakeResponse.getSerializedSize()];
|
||||
assertEquals(expectedHandshakeResponse.getSerializedSize(),
|
||||
clientHandshakeState.readMessage(
|
||||
responderHandshakeBytes, 0, responderHandshakeBytes.length,
|
||||
handshakeResponsePayload, 0));
|
||||
handshakeResponsePlaintext, 0));
|
||||
|
||||
assertEquals(expectedHandshakeResponse, NoiseTunnelProtos.HandshakeResponse.parseFrom(handshakeResponsePlaintext));
|
||||
|
||||
final byte[] serverPublicKey = new byte[32];
|
||||
clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
|
||||
@@ -78,27 +82,35 @@ class NoiseAnonymousHandlerTest extends AbstractNoiseHandlerTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshakeWithRequest() throws ShortBufferException, BadPaddingException {
|
||||
void handleCompleteHandshakeWithRequest() throws Exception {
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAnonymousHandler.class));
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
|
||||
|
||||
final CipherStatePair cipherStatePair = assertDoesNotThrow(() -> doHandshake("ping".getBytes()));
|
||||
final byte[] handshakePlaintext = baseHandshakeInit()
|
||||
.setFastOpenRequest(ByteString.copyFromUtf8("ping")).build()
|
||||
.toByteArray();
|
||||
|
||||
final CipherStatePair cipherStatePair = doHandshake(handshakePlaintext);
|
||||
final byte[] response = readNextPlaintext(cipherStatePair);
|
||||
assertArrayEquals(response, "pong".getBytes());
|
||||
|
||||
assertEquals(new NoiseIdentityDeterminedEvent(Optional.empty()), getNoiseHandshakeCompleteEvent());
|
||||
assertEquals(
|
||||
new NoiseIdentityDeterminedEvent(Optional.empty(), REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE),
|
||||
getNoiseHandshakeCompleteEvent());
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshakeNoRequest() throws ShortBufferException, BadPaddingException {
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAnonymousHandler.class));
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
|
||||
|
||||
final CipherStatePair cipherStatePair = assertDoesNotThrow(() -> doHandshake(new byte[0]));
|
||||
final CipherStatePair cipherStatePair = assertDoesNotThrow(() -> doHandshake());
|
||||
assertNull(readNextPlaintext(cipherStatePair));
|
||||
|
||||
assertEquals(new NoiseIdentityDeterminedEvent(Optional.empty()), getNoiseHandshakeCompleteEvent());
|
||||
assertEquals(
|
||||
new NoiseIdentityDeterminedEvent(Optional.empty(), REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE),
|
||||
getNoiseHandshakeCompleteEvent());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,18 +8,20 @@ import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verifyNoInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.southernstorm.noise.protocol.CipherStatePair;
|
||||
import com.southernstorm.noise.protocol.HandshakeState;
|
||||
import com.southernstorm.noise.protocol.Noise;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelFuture;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.util.internal.EmptyArrays;
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.util.Optional;
|
||||
@@ -28,40 +30,24 @@ import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import javax.crypto.BadPaddingException;
|
||||
import javax.crypto.ShortBufferException;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.signal.libsignal.protocol.ecc.Curve;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseClientTransportHandler;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
|
||||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||
|
||||
class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||
|
||||
private ClientPublicKeysManager clientPublicKeysManager;
|
||||
private final ECKeyPair clientKeyPair = Curve.generateKeyPair();
|
||||
|
||||
@Override
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
clientPublicKeysManager = mock(ClientPublicKeysManager.class);
|
||||
|
||||
super.setUp();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NoiseAuthenticatedHandler getHandler(final ECKeyPair serverKeyPair) {
|
||||
return new NoiseAuthenticatedHandler(clientPublicKeysManager, serverKeyPair);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected CipherStatePair doHandshake() throws Throwable {
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(1, Device.MAXIMUM_DEVICE_ID + 1);
|
||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
|
||||
return doHandshake(identityPayload(accountIdentifier, deviceId));
|
||||
@@ -71,7 +57,7 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||
void handleCompleteHandshakeNoInitialRequest() throws Throwable {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||
@@ -81,7 +67,10 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||
|
||||
assertNull(readNextPlaintext(doHandshake(identityPayload(accountIdentifier, deviceId))));
|
||||
|
||||
assertEquals(new NoiseIdentityDeterminedEvent(Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))),
|
||||
assertEquals(
|
||||
new NoiseIdentityDeterminedEvent(
|
||||
Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)),
|
||||
REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE),
|
||||
getNoiseHandshakeCompleteEvent());
|
||||
}
|
||||
|
||||
@@ -89,7 +78,7 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||
void handleCompleteHandshakeWithInitialRequest() throws Throwable {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||
@@ -97,15 +86,19 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
|
||||
|
||||
final ByteBuffer bb = ByteBuffer.allocate(17 + 4);
|
||||
bb.put(identityPayload(accountIdentifier, deviceId));
|
||||
bb.put("ping".getBytes());
|
||||
final byte[] handshakeInit = identifiedHandshakeInit(accountIdentifier, deviceId)
|
||||
.setFastOpenRequest(ByteString.copyFromUtf8("ping"))
|
||||
.build()
|
||||
.toByteArray();
|
||||
|
||||
final byte[] response = readNextPlaintext(doHandshake(bb.array()));
|
||||
assertEquals(response.length, 4);
|
||||
assertEquals(new String(response), "pong");
|
||||
final byte[] response = readNextPlaintext(doHandshake(handshakeInit));
|
||||
assertEquals(4, response.length);
|
||||
assertEquals("pong", new String(response));
|
||||
|
||||
assertEquals(new NoiseIdentityDeterminedEvent(Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))),
|
||||
assertEquals(
|
||||
new NoiseIdentityDeterminedEvent(
|
||||
Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)),
|
||||
REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE),
|
||||
getNoiseHandshakeCompleteEvent());
|
||||
}
|
||||
|
||||
@@ -113,7 +106,7 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||
void handleCompleteHandshakeMissingIdentityInformation() {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
|
||||
|
||||
assertThrows(NoiseHandshakeException.class, () -> doHandshake(EmptyArrays.EMPTY_BYTES));
|
||||
|
||||
@@ -121,7 +114,7 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||
|
||||
assertNull(getNoiseHandshakeCompleteEvent());
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class),
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class),
|
||||
"Handshake handler should not remove self from pipeline after failed handshake");
|
||||
|
||||
assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class),
|
||||
@@ -132,7 +125,7 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||
void handleCompleteHandshakeMalformedIdentityInformation() {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
|
||||
|
||||
// no deviceId byte
|
||||
byte[] malformedIdentityPayload = UUIDUtil.toBytes(UUID.randomUUID());
|
||||
@@ -142,7 +135,7 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||
|
||||
assertNull(getNoiseHandshakeCompleteEvent());
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class),
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class),
|
||||
"Handshake handler should not remove self from pipeline after failed handshake");
|
||||
|
||||
assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class),
|
||||
@@ -150,10 +143,10 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshakeUnrecognizedDevice() {
|
||||
void handleCompleteHandshakeUnrecognizedDevice() throws Throwable {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||
@@ -161,11 +154,13 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
|
||||
|
||||
assertThrows(ClientAuthenticationException.class, () -> doHandshake(identityPayload(accountIdentifier, deviceId)));
|
||||
doHandshake(
|
||||
identityPayload(accountIdentifier, deviceId),
|
||||
NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY);
|
||||
|
||||
assertNull(getNoiseHandshakeCompleteEvent());
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class),
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class),
|
||||
"Handshake handler should not remove self from pipeline after failed handshake");
|
||||
|
||||
assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class),
|
||||
@@ -173,10 +168,10 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshakePublicKeyMismatch() {
|
||||
void handleCompleteHandshakePublicKeyMismatch() throws Throwable {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||
@@ -184,18 +179,21 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(Curve.generateKeyPair().getPublicKey())));
|
||||
|
||||
assertThrows(ClientAuthenticationException.class, () -> doHandshake(identityPayload(accountIdentifier, deviceId)));
|
||||
doHandshake(
|
||||
identityPayload(accountIdentifier, deviceId),
|
||||
NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY);
|
||||
|
||||
assertNull(getNoiseHandshakeCompleteEvent());
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class),
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class),
|
||||
"Handshake handler should not remove self from pipeline after failed handshake");
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleInvalidExtraWrites() throws NoSuchAlgorithmException, ShortBufferException, InterruptedException {
|
||||
void handleInvalidExtraWrites()
|
||||
throws NoSuchAlgorithmException, ShortBufferException, InterruptedException {
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||
@@ -205,25 +203,23 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||
final CompletableFuture<Optional<ECPublicKey>> findPublicKeyFuture = new CompletableFuture<>();
|
||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture);
|
||||
|
||||
final ByteBuf initiatorMessageFrame = Unpooled.wrappedBuffer(
|
||||
initiatorHandshakeMessage(clientHandshakeState, identityPayload(accountIdentifier, deviceId)));
|
||||
assertTrue(embeddedChannel.writeOneInbound(initiatorMessageFrame).await().isSuccess());
|
||||
final NoiseHandshakeInit handshakeInit = new NoiseHandshakeInit(
|
||||
REMOTE_ADDRESS,
|
||||
HandshakePattern.IK,
|
||||
Unpooled.wrappedBuffer(
|
||||
initiatorHandshakeMessage(clientHandshakeState, identityPayload(accountIdentifier, deviceId))));
|
||||
assertTrue(embeddedChannel.writeOneInbound(handshakeInit).await().isSuccess());
|
||||
|
||||
// While waiting for the public key, send another message
|
||||
final ChannelFuture f = embeddedChannel.writeOneInbound(Unpooled.wrappedBuffer(new byte[0])).await();
|
||||
assertInstanceOf(NoiseHandshakeException.class, f.exceptionNow());
|
||||
assertInstanceOf(IllegalArgumentException.class, f.exceptionNow());
|
||||
|
||||
findPublicKeyFuture.complete(Optional.of(clientKeyPair.getPublicKey()));
|
||||
embeddedChannel.runPendingTasks();
|
||||
|
||||
// shouldn't return any response or error, we've already processed an error
|
||||
embeddedChannel.checkException();
|
||||
assertNull(embeddedChannel.outboundMessages().poll());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void handleOversizeHandshakeMessage() {
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
final byte[] big = TestRandomUtil.nextBytes(Noise.MAX_PACKET_LEN + 1);
|
||||
ByteBuffer.wrap(big)
|
||||
.put(UUIDUtil.toBytes(UUID.randomUUID()))
|
||||
@@ -231,6 +227,15 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||
assertThrows(NoiseHandshakeException.class, () -> doHandshake(big));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void handleKeyLookupError() throws Throwable {
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||
.thenReturn(CompletableFuture.failedFuture(new IOException()));
|
||||
assertThrows(IOException.class, () -> doHandshake(identityPayload(accountIdentifier, deviceId)));
|
||||
}
|
||||
|
||||
private HandshakeState clientHandshakeState() throws NoSuchAlgorithmException {
|
||||
final HandshakeState clientHandshakeState =
|
||||
new HandshakeState(HandshakePattern.IK.protocol(), HandshakeState.INITIATOR);
|
||||
@@ -262,15 +267,22 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||
}
|
||||
|
||||
private CipherStatePair doHandshake(final byte[] payload) throws Throwable {
|
||||
return doHandshake(payload, NoiseTunnelProtos.HandshakeResponse.Code.OK);
|
||||
}
|
||||
|
||||
private CipherStatePair doHandshake(final byte[] payload, final NoiseTunnelProtos.HandshakeResponse.Code expectedStatus) throws Throwable {
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
|
||||
final HandshakeState clientHandshakeState = clientHandshakeState();
|
||||
final byte[] initiatorMessage = initiatorHandshakeMessage(clientHandshakeState, payload);
|
||||
|
||||
final ByteBuf initiatorMessageFrame = Unpooled.wrappedBuffer(initiatorMessage);
|
||||
final ChannelFuture await = embeddedChannel.writeOneInbound(initiatorMessageFrame).await();
|
||||
assertEquals(0, initiatorMessageFrame.refCnt());
|
||||
if (!await.isSuccess()) {
|
||||
final NoiseHandshakeInit initMessage = new NoiseHandshakeInit(
|
||||
REMOTE_ADDRESS,
|
||||
HandshakePattern.IK,
|
||||
Unpooled.wrappedBuffer(initiatorMessage));
|
||||
final ChannelFuture await = embeddedChannel.writeOneInbound(initMessage).await();
|
||||
assertEquals(0, initMessage.refCnt());
|
||||
if (!await.isSuccess() && expectedStatus == NoiseTunnelProtos.HandshakeResponse.Code.OK) {
|
||||
throw await.cause();
|
||||
}
|
||||
|
||||
@@ -280,17 +292,27 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||
// and issue a "handshake complete" event.
|
||||
embeddedChannel.runPendingTasks();
|
||||
|
||||
// rethrow if running the task caused an error
|
||||
embeddedChannel.checkException();
|
||||
// rethrow if running the task caused an error, and the caller isn't expecting an error
|
||||
if (expectedStatus == NoiseTunnelProtos.HandshakeResponse.Code.OK) {
|
||||
embeddedChannel.checkException();
|
||||
}
|
||||
|
||||
assertFalse(embeddedChannel.outboundMessages().isEmpty());
|
||||
|
||||
final ByteBuf serverStaticKeyMessageFrame = (ByteBuf) embeddedChannel.outboundMessages().poll();
|
||||
@SuppressWarnings("DataFlowIssue") final byte[] serverStaticKeyMessageBytes =
|
||||
new byte[serverStaticKeyMessageFrame.readableBytes()];
|
||||
serverStaticKeyMessageFrame.readBytes(serverStaticKeyMessageBytes);
|
||||
final ByteBuf handshakeResponseFrame = (ByteBuf) embeddedChannel.outboundMessages().poll();
|
||||
assertNotNull(handshakeResponseFrame);
|
||||
final byte[] handshakeResponseCiphertextBytes = ByteBufUtil.getBytes(handshakeResponseFrame);
|
||||
|
||||
assertEquals(readHandshakeResponse(clientHandshakeState, serverStaticKeyMessageBytes).length, 0);
|
||||
final NoiseTunnelProtos.HandshakeResponse expectedHandshakeResponsePlaintext = NoiseTunnelProtos.HandshakeResponse.newBuilder()
|
||||
.setCode(expectedStatus)
|
||||
.build();
|
||||
|
||||
final byte[] actualHandshakeResponsePlaintext =
|
||||
readHandshakeResponse(clientHandshakeState, handshakeResponseCiphertextBytes);
|
||||
|
||||
assertEquals(
|
||||
expectedHandshakeResponsePlaintext,
|
||||
NoiseTunnelProtos.HandshakeResponse.parseFrom(actualHandshakeResponsePlaintext));
|
||||
|
||||
final byte[] serverPublicKey = new byte[32];
|
||||
clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
|
||||
@@ -299,13 +321,15 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||
return clientHandshakeState.split();
|
||||
}
|
||||
|
||||
private NoiseTunnelProtos.HandshakeInit.Builder identifiedHandshakeInit(final UUID accountIdentifier, final byte deviceId) {
|
||||
return baseHandshakeInit()
|
||||
.setAci(UUIDUtil.toByteString(accountIdentifier))
|
||||
.setDeviceId(deviceId);
|
||||
}
|
||||
|
||||
private static byte[] identityPayload(final UUID accountIdentifier, final byte deviceId) {
|
||||
final ByteBuffer clientIdentityPayloadBuffer = ByteBuffer.allocate(17);
|
||||
clientIdentityPayloadBuffer.putLong(accountIdentifier.getMostSignificantBits());
|
||||
clientIdentityPayloadBuffer.putLong(accountIdentifier.getLeastSignificantBits());
|
||||
clientIdentityPayloadBuffer.put(deviceId);
|
||||
clientIdentityPayloadBuffer.flip();
|
||||
return clientIdentityPayloadBuffer.array();
|
||||
private byte[] identityPayload(final UUID accountIdentifier, final byte deviceId) {
|
||||
return identifiedHandshakeInit(accountIdentifier, deviceId)
|
||||
.build()
|
||||
.toByteArray();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,10 +10,10 @@ import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectProtos;
|
||||
public record CloseFrameEvent(CloseReason closeReason, CloseInitiator closeInitiator, String reason) {
|
||||
|
||||
public enum CloseReason {
|
||||
OK,
|
||||
SERVER_CLOSED,
|
||||
NOISE_ERROR,
|
||||
NOISE_HANDSHAKE_ERROR,
|
||||
AUTHENTICATION_ERROR,
|
||||
INTERNAL_SERVER_ERROR,
|
||||
UNKNOWN
|
||||
}
|
||||
@@ -27,27 +27,27 @@ public record CloseFrameEvent(CloseReason closeReason, CloseInitiator closeIniti
|
||||
CloseWebSocketFrame closeWebSocketFrame,
|
||||
CloseInitiator closeInitiator) {
|
||||
final CloseReason code = switch (closeWebSocketFrame.statusCode()) {
|
||||
case 4003 -> CloseReason.NOISE_ERROR;
|
||||
case 4001 -> CloseReason.NOISE_HANDSHAKE_ERROR;
|
||||
case 4002 -> CloseReason.AUTHENTICATION_ERROR;
|
||||
case 4002 -> CloseReason.NOISE_ERROR;
|
||||
case 1011 -> CloseReason.INTERNAL_SERVER_ERROR;
|
||||
case 1012 -> CloseReason.SERVER_CLOSED;
|
||||
case 1000 -> CloseReason.OK;
|
||||
default -> CloseReason.UNKNOWN;
|
||||
};
|
||||
return new CloseFrameEvent(code, closeInitiator, closeWebSocketFrame.reasonText());
|
||||
}
|
||||
|
||||
public static CloseFrameEvent fromNoiseDirectErrorFrame(
|
||||
NoiseDirectProtos.Error noiseDirectError,
|
||||
public static CloseFrameEvent fromNoiseDirectCloseFrame(
|
||||
NoiseDirectProtos.CloseReason noiseDirectCloseReason,
|
||||
CloseInitiator closeInitiator) {
|
||||
final CloseReason code = switch (noiseDirectError.getType()) {
|
||||
final CloseReason code = switch (noiseDirectCloseReason.getCode()) {
|
||||
case OK -> CloseReason.OK;
|
||||
case HANDSHAKE_ERROR -> CloseReason.NOISE_HANDSHAKE_ERROR;
|
||||
case ENCRYPTION_ERROR -> CloseReason.NOISE_ERROR;
|
||||
case UNAVAILABLE -> CloseReason.SERVER_CLOSED;
|
||||
case INTERNAL_ERROR -> CloseReason.INTERNAL_SERVER_ERROR;
|
||||
case AUTHENTICATION_ERROR -> CloseReason.AUTHENTICATION_ERROR;
|
||||
case UNRECOGNIZED, UNSPECIFIED -> CloseReason.UNKNOWN;
|
||||
};
|
||||
return new CloseFrameEvent(code, closeInitiator, noiseDirectError.getMessage());
|
||||
return new CloseFrameEvent(code, closeInitiator, noiseDirectCloseReason.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,12 +11,10 @@ import io.netty.channel.socket.SocketChannel;
|
||||
import io.netty.channel.socket.nio.NioSocketChannel;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.net.SocketAddress;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import javax.annotation.Nullable;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseTunnelProtos;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.ProxyHandler;
|
||||
|
||||
/**
|
||||
@@ -31,12 +29,10 @@ import org.whispersystems.textsecuregcm.grpc.net.ProxyHandler;
|
||||
class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private final List<ChannelHandler> remoteHandlerStack;
|
||||
@Nullable
|
||||
private final AuthenticatedDevice authenticatedDevice;
|
||||
private final NoiseTunnelProtos.HandshakeInit handshakeInit;
|
||||
|
||||
private final SocketAddress remoteServerAddress;
|
||||
// If provided, will be sent with the payload in the noise handshake
|
||||
private final byte[] fastOpenRequest;
|
||||
|
||||
private final List<Object> pendingReads = new ArrayList<>();
|
||||
|
||||
@@ -44,13 +40,11 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
EstablishRemoteConnectionHandler(
|
||||
final List<ChannelHandler> remoteHandlerStack,
|
||||
@Nullable final AuthenticatedDevice authenticatedDevice,
|
||||
final SocketAddress remoteServerAddress,
|
||||
@Nullable byte[] fastOpenRequest) {
|
||||
final NoiseTunnelProtos.HandshakeInit handshakeInit) {
|
||||
this.remoteHandlerStack = remoteHandlerStack;
|
||||
this.authenticatedDevice = authenticatedDevice;
|
||||
this.handshakeInit = handshakeInit;
|
||||
this.remoteServerAddress = remoteServerAddress;
|
||||
this.fastOpenRequest = fastOpenRequest == null ? new byte[0] : fastOpenRequest;
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -72,16 +66,19 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
|
||||
throws Exception {
|
||||
switch (event) {
|
||||
case ReadyForNoiseHandshakeEvent ignored ->
|
||||
remoteContext.writeAndFlush(Unpooled.wrappedBuffer(initialPayload()))
|
||||
remoteContext.writeAndFlush(Unpooled.wrappedBuffer(handshakeInit.toByteArray()))
|
||||
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||
case NoiseClientHandshakeCompleteEvent(Optional<byte[]> fastResponse) -> {
|
||||
case NoiseClientHandshakeCompleteEvent(NoiseTunnelProtos.HandshakeResponse handshakeResponse) -> {
|
||||
remoteContext.pipeline()
|
||||
.replace(NOISE_HANDSHAKE_HANDLER_NAME, null, new ProxyHandler(localContext.channel()));
|
||||
localContext.pipeline().addLast(new ProxyHandler(remoteContext.channel()));
|
||||
|
||||
// If there was a payload response on the handshake, write it back to our gRPC client
|
||||
fastResponse.ifPresent(plaintext ->
|
||||
localContext.writeAndFlush(Unpooled.wrappedBuffer(plaintext)));
|
||||
if (!handshakeResponse.getFastOpenResponse().isEmpty()) {
|
||||
localContext.writeAndFlush(Unpooled.wrappedBuffer(handshakeResponse
|
||||
.getFastOpenResponse()
|
||||
.asReadOnlyByteBuffer()));
|
||||
}
|
||||
|
||||
// Forward any messages we got from our gRPC client, now will be proxied to the remote context
|
||||
pendingReads.forEach(localContext::fireChannelRead);
|
||||
@@ -120,17 +117,4 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
|
||||
pendingReads.clear();
|
||||
}
|
||||
|
||||
private byte[] initialPayload() {
|
||||
if (authenticatedDevice == null) {
|
||||
return fastOpenRequest;
|
||||
}
|
||||
|
||||
final ByteBuffer bb = ByteBuffer.allocate(17 + fastOpenRequest.length);
|
||||
bb.putLong(authenticatedDevice.accountIdentifier().getMostSignificantBits());
|
||||
bb.putLong(authenticatedDevice.accountIdentifier().getLeastSignificantBits());
|
||||
bb.put(authenticatedDevice.deviceId());
|
||||
bb.put(fastOpenRequest);
|
||||
bb.flip();
|
||||
return bb.array();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseTunnelProtos;
|
||||
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
@@ -12,4 +14,4 @@ import java.util.Optional;
|
||||
* @param fastResponse A response if the client included a request to send in the initiate handshake message payload and
|
||||
* the server included a payload in the handshake response.
|
||||
*/
|
||||
public record NoiseClientHandshakeCompleteEvent(Optional<byte[]> fastResponse) {}
|
||||
public record NoiseClientHandshakeCompleteEvent(NoiseTunnelProtos.HandshakeResponse handshakeResponse) {}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||
|
||||
import com.google.protobuf.InvalidProtocolBufferException;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.buffer.Unpooled;
|
||||
@@ -7,9 +8,10 @@ import io.netty.channel.ChannelDuplexHandler;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
|
||||
|
||||
import java.util.Optional;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseTunnelProtos;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage;
|
||||
|
||||
public class NoiseClientHandshakeHandler extends ChannelDuplexHandler {
|
||||
|
||||
@@ -38,9 +40,13 @@ public class NoiseClientHandshakeHandler extends ChannelDuplexHandler {
|
||||
if (message instanceof ByteBuf frame) {
|
||||
try {
|
||||
final byte[] payload = handshakeHelper.read(ByteBufUtil.getBytes(frame));
|
||||
final Optional<byte[]> fastResponse = Optional.ofNullable(payload.length == 0 ? null : payload);
|
||||
final NoiseTunnelProtos.HandshakeResponse handshakeResponse =
|
||||
NoiseTunnelProtos.HandshakeResponse.parseFrom(payload);
|
||||
|
||||
context.pipeline().replace(this, null, new NoiseClientTransportHandler(handshakeHelper.split()));
|
||||
context.fireUserEventTriggered(new NoiseClientHandshakeCompleteEvent(fastResponse));
|
||||
context.fireUserEventTriggered(new NoiseClientHandshakeCompleteEvent(handshakeResponse));
|
||||
} catch (InvalidProtocolBufferException e) {
|
||||
throw new NoiseHandshakeException("Failed to parse handshake response");
|
||||
} finally {
|
||||
frame.release();
|
||||
}
|
||||
|
||||
@@ -8,9 +8,11 @@ import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelDuplexHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrame;
|
||||
|
||||
/**
|
||||
* A Noise transport handler manages a bidirectional Noise session after a handshake has completed.
|
||||
@@ -72,8 +74,10 @@ public class NoiseClientTransportHandler extends ChannelDuplexHandler {
|
||||
ReferenceCountUtil.release(plaintext);
|
||||
}
|
||||
} else {
|
||||
// Clients only write ByteBufs or close the connection on errors, so any other message is unexpected
|
||||
log.warn("Unexpected object in pipeline: {}", message);
|
||||
if (!(message instanceof CloseWebSocketFrame || message instanceof NoiseDirectFrame)) {
|
||||
// Clients only write ByteBufs or a close frame on errors, so any other message is unexpected
|
||||
log.warn("Unexpected object in pipeline: {}", message);
|
||||
}
|
||||
context.write(message, promise);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,21 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.southernstorm.noise.protocol.Noise;
|
||||
import io.netty.bootstrap.ServerBootstrap;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufAllocator;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.channel.*;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.Channel;
|
||||
import io.netty.channel.ChannelDuplexHandler;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.channel.ChannelInitializer;
|
||||
import io.netty.channel.ChannelOutboundHandlerAdapter;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.local.LocalChannel;
|
||||
import io.netty.channel.local.LocalServerChannel;
|
||||
@@ -17,6 +27,13 @@ import io.netty.handler.codec.haproxy.HAProxyMessageEncoder;
|
||||
import io.netty.handler.codec.http.DefaultHttpHeaders;
|
||||
import io.netty.handler.codec.http.HttpClientCodec;
|
||||
import io.netty.handler.codec.http.HttpHeaders;
|
||||
import io.netty.handler.codec.http.HttpObjectAggregator;
|
||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
|
||||
import io.netty.handler.ssl.SslContextBuilder;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.net.SocketAddress;
|
||||
import java.net.URI;
|
||||
import java.security.cert.X509Certificate;
|
||||
@@ -26,26 +43,21 @@ import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.function.Function;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
import io.netty.handler.codec.http.HttpObjectAggregator;
|
||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
|
||||
import io.netty.handler.ssl.SslContextBuilder;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import javax.net.ssl.SSLException;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseTunnelProtos;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrame;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrameCodec;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectProtos;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.websocket.WebsocketPayloadCodec;
|
||||
|
||||
import javax.net.ssl.SSLException;
|
||||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||
|
||||
public class NoiseTunnelClient implements AutoCloseable {
|
||||
|
||||
private final CompletableFuture<CloseFrameEvent> closeEventFuture;
|
||||
private final CompletableFuture<NoiseClientHandshakeCompleteEvent> handshakeEventFuture;
|
||||
private final CompletableFuture<Void> userCloseFuture;
|
||||
private final ServerBootstrap serverBootstrap;
|
||||
private Channel serverChannel;
|
||||
|
||||
@@ -66,11 +78,10 @@ public class NoiseTunnelClient implements AutoCloseable {
|
||||
FramingType framingType = FramingType.WEBSOCKET;
|
||||
URI websocketUri = ANONYMOUS_WEBSOCKET_URI;
|
||||
HttpHeaders headers = new DefaultHttpHeaders();
|
||||
NoiseTunnelProtos.HandshakeInit.Builder handshakeInit = NoiseTunnelProtos.HandshakeInit.newBuilder();
|
||||
|
||||
boolean authenticated = false;
|
||||
ECKeyPair ecKeyPair = null;
|
||||
UUID accountIdentifier = null;
|
||||
byte deviceId = 0x00;
|
||||
boolean useTls;
|
||||
X509Certificate trustedServerCertificate = null;
|
||||
Supplier<HAProxyMessage> proxyMessageSupplier = null;
|
||||
@@ -86,8 +97,8 @@ public class NoiseTunnelClient implements AutoCloseable {
|
||||
|
||||
public Builder setAuthenticated(final ECKeyPair ecKeyPair, final UUID accountIdentifier, final byte deviceId) {
|
||||
this.authenticated = true;
|
||||
this.accountIdentifier = accountIdentifier;
|
||||
this.deviceId = deviceId;
|
||||
handshakeInit.setAci(UUIDUtil.toByteString(accountIdentifier));
|
||||
handshakeInit.setDeviceId(deviceId);
|
||||
this.ecKeyPair = ecKeyPair;
|
||||
this.websocketUri = AUTHENTICATED_WEBSOCKET_URI;
|
||||
return this;
|
||||
@@ -109,6 +120,16 @@ public class NoiseTunnelClient implements AutoCloseable {
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setUserAgent(final String userAgent) {
|
||||
handshakeInit.setUserAgent(userAgent);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setAcceptLanguage(final String acceptLanguage) {
|
||||
handshakeInit.setAcceptLanguage(acceptLanguage);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setHeaders(final HttpHeaders headers) {
|
||||
this.headers = headers;
|
||||
return this;
|
||||
@@ -155,17 +176,41 @@ public class NoiseTunnelClient implements AutoCloseable {
|
||||
|
||||
handlers.add(new NoiseClientHandshakeHandler(helper));
|
||||
|
||||
// When the noise handshake completes we'll save the response from the server so client users can inspect it
|
||||
final UserEventFuture<NoiseClientHandshakeCompleteEvent> handshakeEventHandler =
|
||||
new UserEventFuture<>(NoiseClientHandshakeCompleteEvent.class);
|
||||
handlers.add(handshakeEventHandler);
|
||||
|
||||
// Whenever the framing layer sends or receives a close frame, it will emit a CloseFrameEvent and we'll save off
|
||||
// information about why the connection was closed.
|
||||
final UserEventFuture<CloseFrameEvent> closeEventHandler = new UserEventFuture<>(CloseFrameEvent.class);
|
||||
handlers.add(closeEventHandler);
|
||||
|
||||
// When the user closes the client, write a normal closure close frame
|
||||
final CompletableFuture<Void> userCloseFuture = new CompletableFuture<>();
|
||||
handlers.add(new ChannelInboundHandlerAdapter() {
|
||||
@Override
|
||||
public void handlerAdded(final ChannelHandlerContext ctx) {
|
||||
userCloseFuture.thenRunAsync(() -> ctx.pipeline().writeAndFlush(switch (framingType) {
|
||||
case WEBSOCKET -> new CloseWebSocketFrame(WebSocketCloseStatus.NORMAL_CLOSURE);
|
||||
case NOISE_DIRECT -> new NoiseDirectFrame(
|
||||
NoiseDirectFrame.FrameType.CLOSE,
|
||||
Unpooled.wrappedBuffer(NoiseDirectProtos.CloseReason
|
||||
.newBuilder()
|
||||
.setCode(NoiseDirectProtos.CloseReason.Code.OK)
|
||||
.build()
|
||||
.toByteArray()));
|
||||
})
|
||||
.addListener(ChannelFutureListener.CLOSE),
|
||||
ctx.executor());
|
||||
}
|
||||
});
|
||||
|
||||
final NoiseTunnelClient client =
|
||||
new NoiseTunnelClient(eventLoopGroup, closeEventHandler.future, fastOpenRequest -> new EstablishRemoteConnectionHandler(
|
||||
new NoiseTunnelClient(eventLoopGroup, closeEventHandler.future, handshakeEventHandler.future, userCloseFuture, fastOpenRequest -> new EstablishRemoteConnectionHandler(
|
||||
handlers,
|
||||
authenticated ? new AuthenticatedDevice(accountIdentifier, deviceId) : null,
|
||||
remoteServerAddress,
|
||||
fastOpenRequest));
|
||||
handshakeInit.setFastOpenRequest(ByteString.copyFrom(fastOpenRequest)).build()));
|
||||
client.start();
|
||||
return client;
|
||||
}
|
||||
@@ -173,9 +218,13 @@ public class NoiseTunnelClient implements AutoCloseable {
|
||||
|
||||
private NoiseTunnelClient(NioEventLoopGroup eventLoopGroup,
|
||||
CompletableFuture<CloseFrameEvent> closeEventFuture,
|
||||
CompletableFuture<NoiseClientHandshakeCompleteEvent> handshakeEventFuture,
|
||||
CompletableFuture<Void> userCloseFuture,
|
||||
Function<byte[], EstablishRemoteConnectionHandler> handler) {
|
||||
|
||||
this.userCloseFuture = userCloseFuture;
|
||||
this.closeEventFuture = closeEventFuture;
|
||||
this.handshakeEventFuture = handshakeEventFuture;
|
||||
this.serverBootstrap = new ServerBootstrap()
|
||||
.localAddress(new LocalAddress("websocket-noise-tunnel-client"))
|
||||
.channel(LocalServerChannel.class)
|
||||
@@ -194,10 +243,10 @@ public class NoiseTunnelClient implements AutoCloseable {
|
||||
.addLast(new ChannelInboundHandlerAdapter() {
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception {
|
||||
if (evt instanceof FastOpenRequestBufferedEvent requestBufferedEvent) {
|
||||
byte[] fastOpenRequest = ByteBufUtil.getBytes(requestBufferedEvent.fastOpenRequest());
|
||||
requestBufferedEvent.fastOpenRequest().release();
|
||||
ctx.pipeline().addLast(handler.apply(fastOpenRequest));
|
||||
if (evt instanceof FastOpenRequestBufferedEvent(ByteBuf fastOpenRequest)) {
|
||||
byte[] fastOpenRequestBytes = ByteBufUtil.getBytes(fastOpenRequest);
|
||||
fastOpenRequest.release();
|
||||
ctx.pipeline().addLast(handler.apply(fastOpenRequestBytes));
|
||||
}
|
||||
super.userEventTriggered(ctx, evt);
|
||||
}
|
||||
@@ -216,7 +265,7 @@ public class NoiseTunnelClient implements AutoCloseable {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception {
|
||||
public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) {
|
||||
if (cls.isInstance(evt)) {
|
||||
future.complete((T) evt);
|
||||
}
|
||||
@@ -236,6 +285,7 @@ public class NoiseTunnelClient implements AutoCloseable {
|
||||
|
||||
@Override
|
||||
public void close() throws InterruptedException {
|
||||
userCloseFuture.complete(null);
|
||||
serverChannel.close().await();
|
||||
}
|
||||
|
||||
@@ -246,6 +296,14 @@ public class NoiseTunnelClient implements AutoCloseable {
|
||||
return closeEventFuture;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return A future that completes when the noise handshake finishes
|
||||
*/
|
||||
public CompletableFuture<NoiseClientHandshakeCompleteEvent> getHandshakeEventFuture() {
|
||||
return handshakeEventFuture;
|
||||
}
|
||||
|
||||
|
||||
private static List<ChannelHandler> noiseDirectHandlerStack(boolean authenticated) {
|
||||
return List.of(
|
||||
new LengthFieldBasedFrameDecoder(Noise.MAX_PACKET_LEN, 1, 2),
|
||||
@@ -259,12 +317,12 @@ public class NoiseTunnelClient implements AutoCloseable {
|
||||
|
||||
@Override
|
||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||
if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.ERROR) {
|
||||
if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.CLOSE) {
|
||||
try {
|
||||
final NoiseDirectProtos.Error errorPayload =
|
||||
NoiseDirectProtos.Error.parseFrom(ByteBufUtil.getBytes(ndf.content()));
|
||||
final NoiseDirectProtos.CloseReason closeReason =
|
||||
NoiseDirectProtos.CloseReason.parseFrom(ByteBufUtil.getBytes(ndf.content()));
|
||||
ctx.fireUserEventTriggered(
|
||||
CloseFrameEvent.fromNoiseDirectErrorFrame(errorPayload, CloseFrameEvent.CloseInitiator.SERVER));
|
||||
CloseFrameEvent.fromNoiseDirectCloseFrame(closeReason, CloseFrameEvent.CloseInitiator.SERVER));
|
||||
} finally {
|
||||
ReferenceCountUtil.release(msg);
|
||||
}
|
||||
@@ -275,11 +333,11 @@ public class NoiseTunnelClient implements AutoCloseable {
|
||||
|
||||
@Override
|
||||
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
|
||||
if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.ERROR) {
|
||||
final NoiseDirectProtos.Error errorPayload =
|
||||
NoiseDirectProtos.Error.parseFrom(ByteBufUtil.getBytes(ndf.content()));
|
||||
if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.CLOSE) {
|
||||
final NoiseDirectProtos.CloseReason errorPayload =
|
||||
NoiseDirectProtos.CloseReason.parseFrom(ByteBufUtil.getBytes(ndf.content()));
|
||||
ctx.fireUserEventTriggered(
|
||||
CloseFrameEvent.fromNoiseDirectErrorFrame(errorPayload, CloseFrameEvent.CloseInitiator.CLIENT));
|
||||
CloseFrameEvent.fromNoiseDirectCloseFrame(errorPayload, CloseFrameEvent.CloseInitiator.CLIENT));
|
||||
}
|
||||
ctx.write(msg, promise);
|
||||
}
|
||||
|
||||
@@ -126,11 +126,13 @@ class TlsWebSocketNoiseTunnelServerIntegrationTest extends AbstractNoiseTunnelSe
|
||||
|
||||
final HttpHeaders headers = new DefaultHttpHeaders()
|
||||
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
|
||||
.add("X-Forwarded-For", remoteAddress)
|
||||
.add("Accept-Language", acceptLanguage)
|
||||
.add("User-Agent", userAgent);
|
||||
.add("X-Forwarded-For", remoteAddress);
|
||||
|
||||
try (final NoiseTunnelClient client = anonymous().setHeaders(headers).build()) {
|
||||
try (final NoiseTunnelClient client = anonymous()
|
||||
.setHeaders(headers)
|
||||
.setUserAgent(userAgent)
|
||||
.setAcceptLanguage(acceptLanguage)
|
||||
.build()) {
|
||||
|
||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||
|
||||
|
||||
@@ -3,12 +3,10 @@ package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.nio.NioEventLoopGroup;
|
||||
import java.util.concurrent.Executor;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.AbstractNoiseTunnelServerIntegrationTest;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
|
||||
|
||||
@@ -6,9 +6,9 @@ import static org.junit.jupiter.api.Assertions.assertNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.junit.jupiter.params.provider.Arguments.argumentSet;
|
||||
import static org.junit.jupiter.params.provider.Arguments.arguments;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
import com.google.common.net.InetAddresses;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
@@ -17,7 +17,6 @@ import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.handler.codec.http.DefaultHttpHeaders;
|
||||
import io.netty.handler.codec.http.HttpHeaders;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
||||
import io.netty.util.Attribute;
|
||||
import java.net.InetAddress;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.net.SocketAddress;
|
||||
@@ -32,13 +31,10 @@ import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.signal.libsignal.protocol.ecc.Curve;
|
||||
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.AbstractLeakDetectionTest;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseAnonymousHandler;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseAuthenticatedHandler;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeInit;
|
||||
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
|
||||
|
||||
class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
|
||||
|
||||
@@ -84,9 +80,7 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
|
||||
userEventRecordingHandler = new UserEventRecordingHandler();
|
||||
|
||||
embeddedChannel = new MutableRemoteAddressEmbeddedChannel(
|
||||
new WebsocketHandshakeCompleteHandler(mock(ClientPublicKeysManager.class),
|
||||
Curve.generateKeyPair(),
|
||||
RECOGNIZED_PROXY_SECRET),
|
||||
new WebsocketHandshakeCompleteHandler(RECOGNIZED_PROXY_SECRET),
|
||||
userEventRecordingHandler);
|
||||
|
||||
embeddedChannel.setRemoteAddress(new InetSocketAddress("127.0.0.1", 0));
|
||||
@@ -94,22 +88,25 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void handleWebSocketHandshakeComplete(final String uri, final Class<? extends ChannelHandler> expectedHandlerClass) {
|
||||
void handleWebSocketHandshakeComplete(final String uri, final HandshakePattern pattern) {
|
||||
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
|
||||
new WebSocketServerProtocolHandler.HandshakeComplete(uri, new DefaultHttpHeaders(), null);
|
||||
|
||||
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
|
||||
|
||||
assertNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteHandler.class));
|
||||
assertNotNull(embeddedChannel.pipeline().get(expectedHandlerClass));
|
||||
|
||||
assertEquals(List.of(handshakeCompleteEvent), userEventRecordingHandler.getReceivedEvents());
|
||||
|
||||
final byte[] payload = TestRandomUtil.nextBytes(100);
|
||||
embeddedChannel.pipeline().fireChannelRead(Unpooled.wrappedBuffer(payload));
|
||||
assertNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteHandler.class));
|
||||
final NoiseHandshakeInit init = (NoiseHandshakeInit) embeddedChannel.inboundMessages().poll();
|
||||
assertNotNull(init);
|
||||
assertEquals(init.getHandshakePattern(), pattern);
|
||||
}
|
||||
|
||||
private static List<Arguments> handleWebSocketHandshakeComplete() {
|
||||
return List.of(
|
||||
Arguments.of(NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH, NoiseAuthenticatedHandler.class),
|
||||
Arguments.of(NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH, NoiseAnonymousHandler.class));
|
||||
Arguments.of(NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH, HandshakePattern.IK),
|
||||
Arguments.of(NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH, HandshakePattern.NK));
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -141,13 +138,19 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
|
||||
embeddedChannel.setRemoteAddress(remoteAddress);
|
||||
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
|
||||
|
||||
|
||||
|
||||
assertEquals(expectedRemoteAddress,
|
||||
Optional.ofNullable(embeddedChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY))
|
||||
.map(Attribute::get)
|
||||
.map(RequestAttributes::remoteAddress)
|
||||
final byte[] payload = TestRandomUtil.nextBytes(100);
|
||||
embeddedChannel.pipeline().fireChannelRead(Unpooled.wrappedBuffer(payload));
|
||||
final NoiseHandshakeInit init = (NoiseHandshakeInit) embeddedChannel.inboundMessages().poll();
|
||||
assertEquals(
|
||||
expectedRemoteAddress,
|
||||
Optional.ofNullable(init)
|
||||
.map(NoiseHandshakeInit::getRemoteAddress)
|
||||
.orElse(null));
|
||||
if (expectedRemoteAddress == null) {
|
||||
assertThrows(IllegalStateException.class, embeddedChannel::checkException);
|
||||
} else {
|
||||
assertNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteHandler.class));
|
||||
}
|
||||
}
|
||||
|
||||
private static List<Arguments> getRemoteAddress() {
|
||||
|
||||
Reference in New Issue
Block a user