Use a custom redis pubsub implementation rather than Jedis.

// FREEBIE
This commit is contained in:
Moxie Marlinspike
2015-03-16 15:05:33 -07:00
parent e79861c30a
commit c7e0cc1158
26 changed files with 1254 additions and 292 deletions

View File

@@ -0,0 +1,127 @@
package org.whispersystems.dispatch;
import com.google.common.base.Optional;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExternalResource;
import org.mockito.ArgumentCaptor;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.whispersystems.dispatch.io.RedisPubSubConnectionFactory;
import org.whispersystems.dispatch.redis.PubSubConnection;
import org.whispersystems.dispatch.redis.PubSubReply;
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.*;
public class DispatchManagerTest {
private PubSubConnection pubSubConnection;
private RedisPubSubConnectionFactory socketFactory;
private DispatchManager dispatchManager;
private PubSubReplyInputStream pubSubReplyInputStream;
@Rule
public ExternalResource resource = new ExternalResource() {
@Override
protected void before() throws Throwable {
pubSubConnection = mock(PubSubConnection.class );
socketFactory = mock(RedisPubSubConnectionFactory.class);
pubSubReplyInputStream = new PubSubReplyInputStream();
when(socketFactory.connect()).thenReturn(pubSubConnection);
when(pubSubConnection.read()).thenAnswer(new Answer<PubSubReply>() {
@Override
public PubSubReply answer(InvocationOnMock invocationOnMock) throws Throwable {
return pubSubReplyInputStream.read();
}
});
dispatchManager = new DispatchManager(socketFactory, Optional.<DispatchChannel>absent());
dispatchManager.start();
}
@Override
protected void after() {
}
};
@Test
public void testConnect() {
verify(socketFactory).connect();
}
@Test
public void testSubscribe() throws IOException {
DispatchChannel dispatchChannel = mock(DispatchChannel.class);
dispatchManager.subscribe("foo", dispatchChannel);
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.SUBSCRIBE, "foo", Optional.<byte[]>absent()));
verify(dispatchChannel, timeout(1000)).onDispatchSubscribed(eq("foo"));
}
@Test
public void testSubscribeUnsubscribe() throws IOException {
DispatchChannel dispatchChannel = mock(DispatchChannel.class);
dispatchManager.subscribe("foo", dispatchChannel);
dispatchManager.unsubscribe("foo", dispatchChannel);
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.SUBSCRIBE, "foo", Optional.<byte[]>absent()));
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.UNSUBSCRIBE, "foo", Optional.<byte[]>absent()));
verify(dispatchChannel, timeout(1000)).onDispatchUnsubscribed(eq("foo"));
}
@Test
public void testMessages() throws IOException {
DispatchChannel fooChannel = mock(DispatchChannel.class);
DispatchChannel barChannel = mock(DispatchChannel.class);
dispatchManager.subscribe("foo", fooChannel);
dispatchManager.subscribe("bar", barChannel);
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.SUBSCRIBE, "foo", Optional.<byte[]>absent()));
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.SUBSCRIBE, "bar", Optional.<byte[]>absent()));
verify(fooChannel, timeout(1000)).onDispatchSubscribed(eq("foo"));
verify(barChannel, timeout(1000)).onDispatchSubscribed(eq("bar"));
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.MESSAGE, "foo", Optional.of("hello".getBytes())));
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.MESSAGE, "bar", Optional.of("there".getBytes())));
ArgumentCaptor<byte[]> captor = ArgumentCaptor.forClass(byte[].class);
verify(fooChannel, timeout(1000)).onDispatchMessage(eq("foo"), captor.capture());
assertArrayEquals("hello".getBytes(), captor.getValue());
verify(barChannel, timeout(1000)).onDispatchMessage(eq("bar"), captor.capture());
assertArrayEquals("there".getBytes(), captor.getValue());
}
private static class PubSubReplyInputStream {
private final List<PubSubReply> pubSubReplyList = new LinkedList<>();
public synchronized PubSubReply read() {
try {
while (pubSubReplyList.isEmpty()) wait();
return pubSubReplyList.remove(0);
} catch (InterruptedException e) {
throw new AssertionError(e);
}
}
public synchronized void write(PubSubReply pubSubReply) {
pubSubReplyList.add(pubSubReply);
notifyAll();
}
}
}

View File

@@ -0,0 +1,263 @@
package org.whispersystems.dispatch.redis;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import static org.junit.Assert.*;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Mockito.*;
public class PubSubConnectionTest {
private static final String REPLY = "*3\r\n" +
"$9\r\n" +
"subscribe\r\n" +
"$5\r\n" +
"abcde\r\n" +
":1\r\n" +
"*3\r\n" +
"$9\r\n" +
"subscribe\r\n" +
"$5\r\n" +
"fghij\r\n" +
":2\r\n" +
"*3\r\n" +
"$9\r\n" +
"subscribe\r\n" +
"$5\r\n" +
"klmno\r\n" +
":2\r\n" +
"*3\r\n" +
"$7\r\n" +
"message\r\n" +
"$5\r\n" +
"abcde\r\n" +
"$10\r\n" +
"1234567890\r\n" +
"*3\r\n" +
"$7\r\n" +
"message\r\n" +
"$5\r\n" +
"klmno\r\n" +
"$10\r\n" +
"0987654321\r\n";
@Test
public void testSubscribe() throws IOException {
// ByteChannel byteChannel = mock(ByteChannel.class);
OutputStream outputStream = mock(OutputStream.class);
Socket socket = mock(Socket.class );
when(socket.getOutputStream()).thenReturn(outputStream);
PubSubConnection connection = new PubSubConnection(socket);
connection.subscribe("foobar");
ArgumentCaptor<byte[]> captor = ArgumentCaptor.forClass(byte[].class);
verify(outputStream).write(captor.capture());
assertArrayEquals(captor.getValue(), "SUBSCRIBE foobar\r\n".getBytes());
}
@Test
public void testUnsubscribe() throws IOException {
OutputStream outputStream = mock(OutputStream.class);
Socket socket = mock(Socket.class );
when(socket.getOutputStream()).thenReturn(outputStream);
PubSubConnection connection = new PubSubConnection(socket);
connection.unsubscribe("bazbar");
ArgumentCaptor<byte[]> captor = ArgumentCaptor.forClass(byte[].class);
verify(outputStream).write(captor.capture());
assertArrayEquals(captor.getValue(), "UNSUBSCRIBE bazbar\r\n".getBytes());
}
@Test
public void testTricklyResponse() throws Exception {
InputStream inputStream = mockInputStreamFor(new TrickleInputStream(REPLY.getBytes()));
OutputStream outputStream = mock(OutputStream.class);
Socket socket = mock(Socket.class );
when(socket.getOutputStream()).thenReturn(outputStream);
when(socket.getInputStream()).thenReturn(inputStream);
PubSubConnection pubSubConnection = new PubSubConnection(socket);
readResponses(pubSubConnection);
}
@Test
public void testFullResponse() throws Exception {
InputStream inputStream = mockInputStreamFor(new FullInputStream(REPLY.getBytes()));
OutputStream outputStream = mock(OutputStream.class);
Socket socket = mock(Socket.class );
when(socket.getOutputStream()).thenReturn(outputStream);
when(socket.getInputStream()).thenReturn(inputStream);
PubSubConnection pubSubConnection = new PubSubConnection(socket);
readResponses(pubSubConnection);
}
@Test
public void testRandomLengthResponse() throws Exception {
InputStream inputStream = mockInputStreamFor(new RandomInputStream(REPLY.getBytes()));
OutputStream outputStream = mock(OutputStream.class);
Socket socket = mock(Socket.class );
when(socket.getOutputStream()).thenReturn(outputStream);
when(socket.getInputStream()).thenReturn(inputStream);
PubSubConnection pubSubConnection = new PubSubConnection(socket);
readResponses(pubSubConnection);
}
private InputStream mockInputStreamFor(final MockInputStream stub) throws IOException {
InputStream result = mock(InputStream.class);
when(result.read()).thenAnswer(new Answer<Integer>() {
@Override
public Integer answer(InvocationOnMock invocationOnMock) throws Throwable {
return stub.read();
}
});
when(result.read(any(byte[].class))).thenAnswer(new Answer<Integer>() {
@Override
public Integer answer(InvocationOnMock invocationOnMock) throws Throwable {
byte[] buffer = (byte[])invocationOnMock.getArguments()[0];
return stub.read(buffer, 0, buffer.length);
}
});
when(result.read(any(byte[].class), anyInt(), anyInt())).thenAnswer(new Answer<Integer>() {
@Override
public Integer answer(InvocationOnMock invocationOnMock) throws Throwable {
byte[] buffer = (byte[]) invocationOnMock.getArguments()[0];
int offset = (int) invocationOnMock.getArguments()[1];
int length = (int) invocationOnMock.getArguments()[2];
return stub.read(buffer, offset, length);
}
});
return result;
}
private void readResponses(PubSubConnection pubSubConnection) throws Exception {
PubSubReply reply = pubSubConnection.read();
assertEquals(reply.getType(), PubSubReply.Type.SUBSCRIBE);
assertEquals(reply.getChannel(), "abcde");
assertFalse(reply.getContent().isPresent());
reply = pubSubConnection.read();
assertEquals(reply.getType(), PubSubReply.Type.SUBSCRIBE);
assertEquals(reply.getChannel(), "fghij");
assertFalse(reply.getContent().isPresent());
reply = pubSubConnection.read();
assertEquals(reply.getType(), PubSubReply.Type.SUBSCRIBE);
assertEquals(reply.getChannel(), "klmno");
assertFalse(reply.getContent().isPresent());
reply = pubSubConnection.read();
assertEquals(reply.getType(), PubSubReply.Type.MESSAGE);
assertEquals(reply.getChannel(), "abcde");
assertArrayEquals(reply.getContent().get(), "1234567890".getBytes());
reply = pubSubConnection.read();
assertEquals(reply.getType(), PubSubReply.Type.MESSAGE);
assertEquals(reply.getChannel(), "klmno");
assertArrayEquals(reply.getContent().get(), "0987654321".getBytes());
}
private interface MockInputStream {
public int read();
public int read(byte[] input, int offset, int length);
}
private static class TrickleInputStream implements MockInputStream {
private final byte[] data;
private int index = 0;
private TrickleInputStream(byte[] data) {
this.data = data;
}
public int read() {
return data[index++];
}
public int read(byte[] input, int offset, int length) {
input[offset] = data[index++];
return 1;
}
}
private static class FullInputStream implements MockInputStream {
private final byte[] data;
private int index = 0;
private FullInputStream(byte[] data) {
this.data = data;
}
public int read() {
return data[index++];
}
public int read(byte[] input, int offset, int length) {
int amount = Math.min(data.length - index, length);
System.arraycopy(data, index, input, offset, amount);
index += length;
return amount;
}
}
private static class RandomInputStream implements MockInputStream {
private final byte[] data;
private int index = 0;
private RandomInputStream(byte[] data) {
this.data = data;
}
public int read() {
return data[index++];
}
public int read(byte[] input, int offset, int length) {
try {
int maxCopy = Math.min(data.length - index, length);
int randomCopy = SecureRandom.getInstance("SHA1PRNG").nextInt(maxCopy) + 1;
int copyAmount = Math.min(maxCopy, randomCopy);
System.arraycopy(data, index, input, offset, copyAmount);
index += copyAmount;
return copyAmount;
} catch (NoSuchAlgorithmException e) {
throw new AssertionError(e);
}
}
}
}

View File

@@ -0,0 +1,51 @@
package org.whispersystems.dispatch.redis.protocol;
import org.junit.Test;
import java.io.IOException;
import static org.junit.Assert.assertEquals;
public class ArrayReplyHeaderTest {
@Test(expected = IOException.class)
public void testNull() throws IOException {
new ArrayReplyHeader(null);
}
@Test(expected = IOException.class)
public void testBadPrefix() throws IOException {
new ArrayReplyHeader(":3");
}
@Test(expected = IOException.class)
public void testEmpty() throws IOException {
new ArrayReplyHeader("");
}
@Test(expected = IOException.class)
public void testTruncated() throws IOException {
new ArrayReplyHeader("*");
}
@Test(expected = IOException.class)
public void testBadNumber() throws IOException {
new ArrayReplyHeader("*ABC");
}
@Test
public void testValid() throws IOException {
assertEquals(4, new ArrayReplyHeader("*4").getElementCount());
}
}

View File

@@ -0,0 +1,36 @@
package org.whispersystems.dispatch.redis.protocol;
import org.junit.Test;
import java.io.IOException;
import static org.junit.Assert.assertEquals;
public class IntReplyHeaderTest {
@Test(expected = IOException.class)
public void testNull() throws IOException {
new IntReply(null);
}
@Test(expected = IOException.class)
public void testEmpty() throws IOException {
new IntReply("");
}
@Test(expected = IOException.class)
public void testBadNumber() throws IOException {
new IntReply(":A");
}
@Test(expected = IOException.class)
public void testBadFormat() throws IOException {
new IntReply("*");
}
@Test
public void testValid() throws IOException {
assertEquals(23, new IntReply(":23").getValue());
}
}

View File

@@ -0,0 +1,47 @@
package org.whispersystems.dispatch.redis.protocol;
import org.junit.Test;
import java.io.IOException;
import static org.junit.Assert.assertEquals;
public class StringReplyHeaderTest {
@Test
public void testNull() {
try {
new StringReplyHeader(null);
throw new AssertionError();
} catch (IOException e) {
// good
}
}
@Test
public void testBadNumber() {
try {
new StringReplyHeader("$100A");
throw new AssertionError();
} catch (IOException e) {
// good
}
}
@Test
public void testBadPrefix() {
try {
new StringReplyHeader("*");
throw new AssertionError();
} catch (IOException e) {
// good
}
}
@Test
public void testValid() throws IOException {
assertEquals(1000, new StringReplyHeader("$1000").getStringLength());
}
}

View File

@@ -44,45 +44,20 @@ import static org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMe
public class WebSocketConnectionTest {
// private static final ObjectMapper mapper = new ObjectMapper();
private static final String VALID_USER = "+14152222222";
private static final String INVALID_USER = "+14151111111";
private static final String VALID_PASSWORD = "secure";
private static final String INVALID_PASSWORD = "insecure";
// private static final StoredMessages storedMessages = mock(StoredMessages.class);
private static final AccountAuthenticator accountAuthenticator = mock(AccountAuthenticator.class);
private static final AccountsManager accountsManager = mock(AccountsManager.class);
private static final PubSubManager pubSubManager = mock(PubSubManager.class );
private static final Account account = mock(Account.class );
private static final Device device = mock(Device.class );
private static final UpgradeRequest upgradeRequest = mock(UpgradeRequest.class );
// private static final Session session = mock(Session.class );
private static final PushSender pushSender = mock(PushSender.class);
@Test
public void testCloseExisting() throws Exception {
MessagesManager storedMessages = mock(MessagesManager.class );
WebSocketConnectListener connectListener = new AuthenticatedConnectListener(accountsManager, pushSender, storedMessages, pubSubManager);
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
Account account = mock(Account.class );
Device device = mock(Device.class );
when(sessionContext.getAuthenticated(Account.class)).thenReturn(Optional.of(account));
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn("+14157777777");
when(device.getId()).thenReturn(1L);
connectListener.onWebSocketConnect(sessionContext);
ArgumentCaptor<PubSubProtos.PubSubMessage> message = ArgumentCaptor.forClass(PubSubProtos.PubSubMessage.class);
verify(pubSubManager).publish(eq(new WebsocketAddress("+14157777777", 1L)), message.capture());
assertEquals(message.getValue().getType().getNumber(), PubSubProtos.PubSubMessage.Type.CLOSE_VALUE);
}
@Test
public void testCredentials() throws Exception {
MessagesManager storedMessages = mock(MessagesManager.class);
@@ -98,10 +73,6 @@ public class WebSocketConnectionTest {
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
// when(session.getUpgradeRequest()).thenReturn(upgradeRequest);
//
// WebsocketController controller = new WebsocketController(accountAuthenticator, accountsManager, pushSender, pubSubManager, storedMessages);
when(upgradeRequest.getParameterMap()).thenReturn(new HashMap<String, String[]>() {{
put("login", new String[] {VALID_USER});
put("password", new String[] {VALID_PASSWORD});
@@ -114,13 +85,6 @@ public class WebSocketConnectionTest {
verify(sessionContext).addListener(any(WebSocketSessionContext.WebSocketEventListener.class));
//
// controller.onWebSocketConnect(session);
// verify(session, never()).close();
// verify(session, never()).close(any(CloseStatus.class));
// verify(session, never()).close(anyInt(), anyString());
when(upgradeRequest.getParameterMap()).thenReturn(new HashMap<String, String[]>() {{
put("login", new String[] {INVALID_USER});
put("password", new String[] {INVALID_PASSWORD});
@@ -128,15 +92,6 @@ public class WebSocketConnectionTest {
account = webSocketAuthenticator.authenticate(upgradeRequest);
assertFalse(account.isPresent());
// when(sessionContext.getAuthenticated(Account.class)).thenReturn(account);
//
// WebSocketClient client = mock(WebSocketClient.class);
// when(sessionContext.getClient()).thenReturn(client);
//
// connectListener.onWebSocketConnect(sessionContext);
//
// verify(sessionContext, times(1)).addListener(any(WebSocketSessionContext.WebSocketEventListener.class));
// verify(client).close(eq(4001), anyString());
}
@Test
@@ -183,12 +138,11 @@ public class WebSocketConnectionTest {
}
});
WebsocketAddress websocketAddress = new WebsocketAddress(account.getNumber(), device.getId());
WebSocketConnection connection = new WebSocketConnection(accountsManager, pushSender, storedMessages,
pubSubManager, account, device, client);
account, device, client);
connection.onConnected();
verify(pubSubManager).subscribe(eq(new WebsocketAddress("+14152222222", 2L)), eq((connection)));
connection.onDispatchSubscribed(websocketAddress.serialize());
verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(Optional.class));
assertTrue(futures.size() == 3);
@@ -205,11 +159,10 @@ public class WebSocketConnectionTest {
add(createMessage("sender2", 3333, false, "third"));
}};
// verify(pushSender, times(2)).sendMessage(eq(account), eq(device), any(OutgoingMessageSignal.class));
verify(pushSender, times(1)).sendMessage(eq(sender1), eq(sender1device), any(OutgoingMessageSignal.class));
connection.onConnectionLost();
verify(pubSubManager).unsubscribe(eq(new WebsocketAddress("+14152222222", 2L)), eq(connection));
connection.onDispatchUnsubscribed(websocketAddress.serialize());
verify(client).close(anyInt(), anyString());
}
private OutgoingMessageSignal createMessage(String sender, long timestamp, boolean receipt, String content) {