mirror of
https://github.com/signalapp/Signal-Server
synced 2026-04-26 05:18:06 +01:00
Refactor WebSocket support to use Redis for pubsub communication.
This commit is contained in:
@@ -29,6 +29,9 @@ public class Account implements Serializable {
|
||||
|
||||
public static final int MEMCACHE_VERION = 2;
|
||||
|
||||
@JsonIgnore
|
||||
private long id;
|
||||
|
||||
@JsonProperty
|
||||
private String number;
|
||||
|
||||
@@ -48,6 +51,14 @@ public class Account implements Serializable {
|
||||
this.supportsSms = supportsSms;
|
||||
}
|
||||
|
||||
public long getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
public void setId(long id) {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
public Optional<Device> getAuthenticatedDevice() {
|
||||
return authenticatedDevice;
|
||||
}
|
||||
|
||||
@@ -95,7 +95,10 @@ public abstract class Accounts {
|
||||
throws SQLException
|
||||
{
|
||||
try {
|
||||
return mapper.readValue(resultSet.getString(DATA), Account.class);
|
||||
Account account = mapper.readValue(resultSet.getString(DATA), Account.class);
|
||||
account.setId(resultSet.getLong(ID));
|
||||
|
||||
return account;
|
||||
} catch (IOException e) {
|
||||
throw new SQLException(e);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
package org.whispersystems.textsecuregcm.storage;
|
||||
|
||||
public interface PubSubListener {
|
||||
|
||||
public void onPubSubMessage(PubSubMessage outgoingMessage);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,157 @@
|
||||
package org.whispersystems.textsecuregcm.storage;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.websocket.InvalidWebsocketAddressException;
|
||||
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import redis.clients.jedis.Jedis;
|
||||
import redis.clients.jedis.JedisPool;
|
||||
import redis.clients.jedis.JedisPubSub;
|
||||
|
||||
public class PubSubManager {
|
||||
|
||||
private final Logger logger = LoggerFactory.getLogger(PubSubManager.class);
|
||||
private final ObjectMapper mapper = new ObjectMapper();
|
||||
private final SubscriptionListener baseListener = new SubscriptionListener();
|
||||
private final Map<WebsocketAddress, PubSubListener> listeners = new HashMap<>();
|
||||
|
||||
private final JedisPool jedisPool;
|
||||
private boolean subscribed = false;
|
||||
|
||||
public PubSubManager(final JedisPool jedisPool) {
|
||||
this.jedisPool = jedisPool;
|
||||
initializePubSubWorker();
|
||||
waitForSubscription();
|
||||
}
|
||||
|
||||
public synchronized void subscribe(WebsocketAddress address, PubSubListener listener) {
|
||||
listeners.put(address, listener);
|
||||
baseListener.subscribe(address.toString());
|
||||
}
|
||||
|
||||
public synchronized void unsubscribe(WebsocketAddress address, PubSubListener listener) {
|
||||
if (listeners.get(address) == listener) {
|
||||
listeners.remove(address);
|
||||
baseListener.unsubscribe(address.toString());
|
||||
}
|
||||
}
|
||||
|
||||
public synchronized boolean publish(WebsocketAddress address, PubSubMessage message) {
|
||||
try {
|
||||
String serialized = mapper.writeValueAsString(message);
|
||||
Jedis jedis = null;
|
||||
|
||||
try {
|
||||
jedis = jedisPool.getResource();
|
||||
return jedis.publish(address.toString(), serialized) != 0;
|
||||
} finally {
|
||||
if (jedis != null)
|
||||
jedisPool.returnResource(jedis);
|
||||
}
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new AssertionError(e);
|
||||
}
|
||||
}
|
||||
|
||||
private synchronized void waitForSubscription() {
|
||||
try {
|
||||
while (!subscribed) {
|
||||
wait();
|
||||
}
|
||||
} catch (InterruptedException e) {
|
||||
throw new AssertionError(e);
|
||||
}
|
||||
}
|
||||
|
||||
private void initializePubSubWorker() {
|
||||
new Thread("PubSubListener") {
|
||||
@Override
|
||||
public void run() {
|
||||
for (;;) {
|
||||
Jedis jedis = null;
|
||||
try {
|
||||
jedis = jedisPool.getResource();
|
||||
jedis.subscribe(baseListener, new WebsocketAddress(0, 0).toString());
|
||||
logger.warn("**** Unsubscribed from holding channel!!! ******");
|
||||
} finally {
|
||||
if (jedis != null)
|
||||
jedisPool.returnResource(jedis);
|
||||
}
|
||||
}
|
||||
}
|
||||
}.start();
|
||||
|
||||
new Thread("PubSubKeepAlive") {
|
||||
@Override
|
||||
public void run() {
|
||||
for (;;) {
|
||||
try {
|
||||
Thread.sleep(20000);
|
||||
publish(new WebsocketAddress(0, 0), new PubSubMessage(0, "foo"));
|
||||
} catch (InterruptedException e) {
|
||||
throw new AssertionError(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}.start();
|
||||
}
|
||||
|
||||
private class SubscriptionListener extends JedisPubSub {
|
||||
|
||||
@Override
|
||||
public void onMessage(String channel, String message) {
|
||||
try {
|
||||
WebsocketAddress address = new WebsocketAddress(channel);
|
||||
PubSubListener listener;
|
||||
|
||||
synchronized (PubSubManager.this) {
|
||||
listener = listeners.get(address);
|
||||
}
|
||||
|
||||
if (listener != null) {
|
||||
listener.onPubSubMessage(mapper.readValue(message, PubSubMessage.class));
|
||||
}
|
||||
} catch (InvalidWebsocketAddressException e) {
|
||||
logger.warn("Address", e);
|
||||
} catch (IOException e) {
|
||||
logger.warn("IOE", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onPMessage(String s, String s2, String s3) {
|
||||
logger.warn("Received PMessage!");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onSubscribe(String channel, int count) {
|
||||
try {
|
||||
WebsocketAddress address = new WebsocketAddress(channel);
|
||||
if (address.getAccountId() == 0 && address.getDeviceId() == 0) {
|
||||
synchronized (PubSubManager.this) {
|
||||
subscribed = true;
|
||||
PubSubManager.this.notifyAll();
|
||||
}
|
||||
}
|
||||
} catch (InvalidWebsocketAddressException e) {
|
||||
logger.warn("Weird address", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onUnsubscribe(String s, int i) {}
|
||||
|
||||
@Override
|
||||
public void onPUnsubscribe(String s, int i) {}
|
||||
|
||||
@Override
|
||||
public void onPSubscribe(String s, int i) {}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package org.whispersystems.textsecuregcm.storage;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
|
||||
@JsonIgnoreProperties(ignoreUnknown = true)
|
||||
public class PubSubMessage {
|
||||
|
||||
public static final int TYPE_QUERY_DB = 1;
|
||||
public static final int TYPE_DELIVER = 2;
|
||||
|
||||
@JsonProperty
|
||||
private int type;
|
||||
|
||||
@JsonProperty
|
||||
private String contents;
|
||||
|
||||
public PubSubMessage() {}
|
||||
|
||||
public PubSubMessage(int type, String contents) {
|
||||
this.type = type;
|
||||
this.contents = contents;
|
||||
}
|
||||
|
||||
public int getType() {
|
||||
return type;
|
||||
}
|
||||
|
||||
public String getContents() {
|
||||
return contents;
|
||||
}
|
||||
}
|
||||
@@ -18,23 +18,49 @@ package org.whispersystems.textsecuregcm.storage;
|
||||
|
||||
import org.whispersystems.textsecuregcm.entities.CryptoEncodingException;
|
||||
import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage;
|
||||
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
public class StoredMessageManager {
|
||||
StoredMessages storedMessages;
|
||||
public StoredMessageManager(StoredMessages storedMessages) {
|
||||
|
||||
private final StoredMessages storedMessages;
|
||||
private final PubSubManager pubSubManager;
|
||||
|
||||
public StoredMessageManager(StoredMessages storedMessages, PubSubManager pubSubManager) {
|
||||
this.storedMessages = storedMessages;
|
||||
this.pubSubManager = pubSubManager;
|
||||
}
|
||||
|
||||
public void storeMessage(Device device, EncryptedOutgoingMessage outgoingMessage)
|
||||
public void storeMessage(Account account, Device device, EncryptedOutgoingMessage outgoingMessage)
|
||||
throws CryptoEncodingException
|
||||
{
|
||||
storedMessages.insert(device.getId(), outgoingMessage.serialize());
|
||||
storeMessage(account, device, outgoingMessage.serialize());
|
||||
}
|
||||
|
||||
public List<String> getStoredMessage(Device device) {
|
||||
return storedMessages.getMessagesForAccountId(device.getId());
|
||||
public void storeMessages(Account account, Device device, List<String> serializedMessages) {
|
||||
for (String serializedMessage : serializedMessages) {
|
||||
storeMessage(account, device, serializedMessage);
|
||||
}
|
||||
}
|
||||
|
||||
private void storeMessage(Account account, Device device, String serializedMessage) {
|
||||
if (device.getFetchesMessages()) {
|
||||
WebsocketAddress address = new WebsocketAddress(account.getId(), device.getId());
|
||||
PubSubMessage pubSubMessage = new PubSubMessage(PubSubMessage.TYPE_DELIVER, serializedMessage);
|
||||
|
||||
if (!pubSubManager.publish(address, pubSubMessage)) {
|
||||
storedMessages.insert(account.getId(), device.getId(), serializedMessage);
|
||||
pubSubManager.publish(address, new PubSubMessage(PubSubMessage.TYPE_QUERY_DB, null));
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
storedMessages.insert(account.getId(), device.getId(), serializedMessage);
|
||||
}
|
||||
|
||||
public List<String> getOutgoingMessages(Account account, Device device) {
|
||||
return storedMessages.getMessagesForDevice(account.getId(), device.getId());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
package org.whispersystems.textsecuregcm.storage;
|
||||
|
||||
import org.skife.jdbi.v2.sqlobject.Bind;
|
||||
import org.skife.jdbi.v2.sqlobject.SqlBatch;
|
||||
import org.skife.jdbi.v2.sqlobject.SqlQuery;
|
||||
import org.skife.jdbi.v2.sqlobject.SqlUpdate;
|
||||
|
||||
@@ -24,9 +25,12 @@ import java.util.List;
|
||||
|
||||
public interface StoredMessages {
|
||||
|
||||
@SqlUpdate("INSERT INTO stored_messages (destination_id, encrypted_message) VALUES (:destination_id, :encrypted_message)")
|
||||
void insert(@Bind("destination_id") long destinationAccountId, @Bind("encrypted_message") String encryptedOutgoingMessage);
|
||||
@SqlUpdate("INSERT INTO messages (account_id, device_id, encrypted_message) VALUES (:account_id, :device_id, :encrypted_message)")
|
||||
void insert(@Bind("account_id") long accountId, @Bind("device_id") long deviceId, @Bind("encrypted_message") String encryptedOutgoingMessage);
|
||||
|
||||
@SqlQuery("SELECT encrypted_message FROM stored_messages WHERE destination_id = :account_id")
|
||||
List<String> getMessagesForAccountId(@Bind("account_id") long accountId);
|
||||
@SqlBatch("INSERT INTO messages (account_id, device_id, encrypted_message) VALUES (:account_id, :device_id, :encrypted_message)")
|
||||
void insert(@Bind("account_id") long accountId, @Bind("device_id") long deviceId, @Bind("encrypted_message") List<String> encryptedOutgoingMessages);
|
||||
|
||||
@SqlQuery("DELETE FROM messages WHERE account_id = :account_id AND device_id = :device_id RETURNING encrypted_message")
|
||||
List<String> getMessagesForDevice(@Bind("account_id") long accountId, @Bind("device_id") long deviceId);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user