Refactor WebSocket support to use Redis for pubsub communication.

This commit is contained in:
Moxie Marlinspike
2014-01-24 12:33:40 -08:00
parent 519f982604
commit 7bb505db4c
19 changed files with 670 additions and 23 deletions

View File

@@ -29,6 +29,9 @@ public class Account implements Serializable {
public static final int MEMCACHE_VERION = 2;
@JsonIgnore
private long id;
@JsonProperty
private String number;
@@ -48,6 +51,14 @@ public class Account implements Serializable {
this.supportsSms = supportsSms;
}
public long getId() {
return id;
}
public void setId(long id) {
this.id = id;
}
public Optional<Device> getAuthenticatedDevice() {
return authenticatedDevice;
}

View File

@@ -95,7 +95,10 @@ public abstract class Accounts {
throws SQLException
{
try {
return mapper.readValue(resultSet.getString(DATA), Account.class);
Account account = mapper.readValue(resultSet.getString(DATA), Account.class);
account.setId(resultSet.getLong(ID));
return account;
} catch (IOException e) {
throw new SQLException(e);
}

View File

@@ -0,0 +1,7 @@
package org.whispersystems.textsecuregcm.storage;
public interface PubSubListener {
public void onPubSubMessage(PubSubMessage outgoingMessage);
}

View File

@@ -0,0 +1,157 @@
package org.whispersystems.textsecuregcm.storage;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.websocket.InvalidWebsocketAddressException;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.JedisPubSub;
public class PubSubManager {
private final Logger logger = LoggerFactory.getLogger(PubSubManager.class);
private final ObjectMapper mapper = new ObjectMapper();
private final SubscriptionListener baseListener = new SubscriptionListener();
private final Map<WebsocketAddress, PubSubListener> listeners = new HashMap<>();
private final JedisPool jedisPool;
private boolean subscribed = false;
public PubSubManager(final JedisPool jedisPool) {
this.jedisPool = jedisPool;
initializePubSubWorker();
waitForSubscription();
}
public synchronized void subscribe(WebsocketAddress address, PubSubListener listener) {
listeners.put(address, listener);
baseListener.subscribe(address.toString());
}
public synchronized void unsubscribe(WebsocketAddress address, PubSubListener listener) {
if (listeners.get(address) == listener) {
listeners.remove(address);
baseListener.unsubscribe(address.toString());
}
}
public synchronized boolean publish(WebsocketAddress address, PubSubMessage message) {
try {
String serialized = mapper.writeValueAsString(message);
Jedis jedis = null;
try {
jedis = jedisPool.getResource();
return jedis.publish(address.toString(), serialized) != 0;
} finally {
if (jedis != null)
jedisPool.returnResource(jedis);
}
} catch (JsonProcessingException e) {
throw new AssertionError(e);
}
}
private synchronized void waitForSubscription() {
try {
while (!subscribed) {
wait();
}
} catch (InterruptedException e) {
throw new AssertionError(e);
}
}
private void initializePubSubWorker() {
new Thread("PubSubListener") {
@Override
public void run() {
for (;;) {
Jedis jedis = null;
try {
jedis = jedisPool.getResource();
jedis.subscribe(baseListener, new WebsocketAddress(0, 0).toString());
logger.warn("**** Unsubscribed from holding channel!!! ******");
} finally {
if (jedis != null)
jedisPool.returnResource(jedis);
}
}
}
}.start();
new Thread("PubSubKeepAlive") {
@Override
public void run() {
for (;;) {
try {
Thread.sleep(20000);
publish(new WebsocketAddress(0, 0), new PubSubMessage(0, "foo"));
} catch (InterruptedException e) {
throw new AssertionError(e);
}
}
}
}.start();
}
private class SubscriptionListener extends JedisPubSub {
@Override
public void onMessage(String channel, String message) {
try {
WebsocketAddress address = new WebsocketAddress(channel);
PubSubListener listener;
synchronized (PubSubManager.this) {
listener = listeners.get(address);
}
if (listener != null) {
listener.onPubSubMessage(mapper.readValue(message, PubSubMessage.class));
}
} catch (InvalidWebsocketAddressException e) {
logger.warn("Address", e);
} catch (IOException e) {
logger.warn("IOE", e);
}
}
@Override
public void onPMessage(String s, String s2, String s3) {
logger.warn("Received PMessage!");
}
@Override
public void onSubscribe(String channel, int count) {
try {
WebsocketAddress address = new WebsocketAddress(channel);
if (address.getAccountId() == 0 && address.getDeviceId() == 0) {
synchronized (PubSubManager.this) {
subscribed = true;
PubSubManager.this.notifyAll();
}
}
} catch (InvalidWebsocketAddressException e) {
logger.warn("Weird address", e);
}
}
@Override
public void onUnsubscribe(String s, int i) {}
@Override
public void onPUnsubscribe(String s, int i) {}
@Override
public void onPSubscribe(String s, int i) {}
}
}

View File

@@ -0,0 +1,32 @@
package org.whispersystems.textsecuregcm.storage;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
@JsonIgnoreProperties(ignoreUnknown = true)
public class PubSubMessage {
public static final int TYPE_QUERY_DB = 1;
public static final int TYPE_DELIVER = 2;
@JsonProperty
private int type;
@JsonProperty
private String contents;
public PubSubMessage() {}
public PubSubMessage(int type, String contents) {
this.type = type;
this.contents = contents;
}
public int getType() {
return type;
}
public String getContents() {
return contents;
}
}

View File

@@ -18,23 +18,49 @@ package org.whispersystems.textsecuregcm.storage;
import org.whispersystems.textsecuregcm.entities.CryptoEncodingException;
import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import java.io.IOException;
import java.util.List;
public class StoredMessageManager {
StoredMessages storedMessages;
public StoredMessageManager(StoredMessages storedMessages) {
private final StoredMessages storedMessages;
private final PubSubManager pubSubManager;
public StoredMessageManager(StoredMessages storedMessages, PubSubManager pubSubManager) {
this.storedMessages = storedMessages;
this.pubSubManager = pubSubManager;
}
public void storeMessage(Device device, EncryptedOutgoingMessage outgoingMessage)
public void storeMessage(Account account, Device device, EncryptedOutgoingMessage outgoingMessage)
throws CryptoEncodingException
{
storedMessages.insert(device.getId(), outgoingMessage.serialize());
storeMessage(account, device, outgoingMessage.serialize());
}
public List<String> getStoredMessage(Device device) {
return storedMessages.getMessagesForAccountId(device.getId());
public void storeMessages(Account account, Device device, List<String> serializedMessages) {
for (String serializedMessage : serializedMessages) {
storeMessage(account, device, serializedMessage);
}
}
private void storeMessage(Account account, Device device, String serializedMessage) {
if (device.getFetchesMessages()) {
WebsocketAddress address = new WebsocketAddress(account.getId(), device.getId());
PubSubMessage pubSubMessage = new PubSubMessage(PubSubMessage.TYPE_DELIVER, serializedMessage);
if (!pubSubManager.publish(address, pubSubMessage)) {
storedMessages.insert(account.getId(), device.getId(), serializedMessage);
pubSubManager.publish(address, new PubSubMessage(PubSubMessage.TYPE_QUERY_DB, null));
}
return;
}
storedMessages.insert(account.getId(), device.getId(), serializedMessage);
}
public List<String> getOutgoingMessages(Account account, Device device) {
return storedMessages.getMessagesForDevice(account.getId(), device.getId());
}
}

View File

@@ -17,6 +17,7 @@
package org.whispersystems.textsecuregcm.storage;
import org.skife.jdbi.v2.sqlobject.Bind;
import org.skife.jdbi.v2.sqlobject.SqlBatch;
import org.skife.jdbi.v2.sqlobject.SqlQuery;
import org.skife.jdbi.v2.sqlobject.SqlUpdate;
@@ -24,9 +25,12 @@ import java.util.List;
public interface StoredMessages {
@SqlUpdate("INSERT INTO stored_messages (destination_id, encrypted_message) VALUES (:destination_id, :encrypted_message)")
void insert(@Bind("destination_id") long destinationAccountId, @Bind("encrypted_message") String encryptedOutgoingMessage);
@SqlUpdate("INSERT INTO messages (account_id, device_id, encrypted_message) VALUES (:account_id, :device_id, :encrypted_message)")
void insert(@Bind("account_id") long accountId, @Bind("device_id") long deviceId, @Bind("encrypted_message") String encryptedOutgoingMessage);
@SqlQuery("SELECT encrypted_message FROM stored_messages WHERE destination_id = :account_id")
List<String> getMessagesForAccountId(@Bind("account_id") long accountId);
@SqlBatch("INSERT INTO messages (account_id, device_id, encrypted_message) VALUES (:account_id, :device_id, :encrypted_message)")
void insert(@Bind("account_id") long accountId, @Bind("device_id") long deviceId, @Bind("encrypted_message") List<String> encryptedOutgoingMessages);
@SqlQuery("DELETE FROM messages WHERE account_id = :account_id AND device_id = :device_id RETURNING encrypted_message")
List<String> getMessagesForDevice(@Bind("account_id") long accountId, @Bind("device_id") long deviceId);
}