mirror of
https://github.com/signalapp/Signal-Server
synced 2026-04-25 13:48:06 +01:00
Fix for PubSub channel.
1) Create channels based on numbers rather than DB row ids. 2) Ensure that stored messages are cleared at reregistration time.
This commit is contained in:
@@ -172,7 +172,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
||||
deviceAuthenticator,
|
||||
Device.class, "WhisperServer"));
|
||||
|
||||
environment.jersey().register(new AccountController(pendingAccountsManager, accountsManager, rateLimiters, smsSender));
|
||||
environment.jersey().register(new AccountController(pendingAccountsManager, accountsManager, rateLimiters, smsSender, storedMessages));
|
||||
environment.jersey().register(new DeviceController(pendingDevicesManager, accountsManager, rateLimiters));
|
||||
environment.jersey().register(new DirectoryController(rateLimiters, directory));
|
||||
environment.jersey().register(new FederationControllerV1(accountsManager, attachmentController, messageController, keysControllerV1));
|
||||
|
||||
@@ -34,8 +34,10 @@ import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.storage.PendingAccountsManager;
|
||||
import org.whispersystems.textsecuregcm.storage.StoredMessages;
|
||||
import org.whispersystems.textsecuregcm.util.Util;
|
||||
import org.whispersystems.textsecuregcm.util.VerificationCode;
|
||||
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
|
||||
|
||||
import javax.validation.Valid;
|
||||
import javax.ws.rs.Consumes;
|
||||
@@ -65,16 +67,19 @@ public class AccountController {
|
||||
private final AccountsManager accounts;
|
||||
private final RateLimiters rateLimiters;
|
||||
private final SmsSender smsSender;
|
||||
private final StoredMessages storedMessages;
|
||||
|
||||
public AccountController(PendingAccountsManager pendingAccounts,
|
||||
AccountsManager accounts,
|
||||
RateLimiters rateLimiters,
|
||||
SmsSender smsSenderFactory)
|
||||
SmsSender smsSenderFactory,
|
||||
StoredMessages storedMessages)
|
||||
{
|
||||
this.pendingAccounts = pendingAccounts;
|
||||
this.accounts = accounts;
|
||||
this.rateLimiters = rateLimiters;
|
||||
this.smsSender = smsSenderFactory;
|
||||
this.storedMessages = storedMessages;
|
||||
}
|
||||
|
||||
@Timed
|
||||
@@ -153,7 +158,7 @@ public class AccountController {
|
||||
account.addDevice(device);
|
||||
|
||||
accounts.create(account);
|
||||
|
||||
storedMessages.clear(new WebsocketAddress(number, Device.MASTER_ID));
|
||||
pendingAccounts.remove(number);
|
||||
|
||||
logger.debug("Stored device...");
|
||||
|
||||
@@ -94,7 +94,7 @@ public class WebsocketController implements WebSocketListener, PubSubListener {
|
||||
|
||||
this.account = account.get();
|
||||
this.device = account.get().getAuthenticatedDevice().get();
|
||||
this.address = new WebsocketAddress(this.account.getId(), this.device.getId());
|
||||
this.address = new WebsocketAddress(this.account.getNumber(), this.device.getId());
|
||||
this.session = session;
|
||||
|
||||
this.session.setIdleTimeout(10 * 60 * 1000);
|
||||
@@ -148,7 +148,7 @@ public class WebsocketController implements WebSocketListener, PubSubListener {
|
||||
pushSender.sendMessage(account, device, remainingMessage);
|
||||
} catch (NotPushRegisteredException | TransientPushFailureException e) {
|
||||
logger.warn("onWebSocketClose", e);
|
||||
storedMessages.insert(account.getId(), device.getId(), remainingMessage);
|
||||
storedMessages.insert(address, remainingMessage);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -208,7 +208,7 @@ public class WebsocketController implements WebSocketListener, PubSubListener {
|
||||
}
|
||||
|
||||
private void handleQueryDatabase() {
|
||||
List<PendingMessage> messages = storedMessages.getMessagesForDevice(account.getId(), device.getId());
|
||||
List<PendingMessage> messages = storedMessages.getMessagesForDevice(address);
|
||||
|
||||
for (PendingMessage message : messages) {
|
||||
handleDeliverOutgoingMessage(message);
|
||||
|
||||
@@ -103,16 +103,16 @@ public class APNSender implements Managed {
|
||||
throws TransientPushFailureException
|
||||
{
|
||||
try {
|
||||
String serializedPendingMessage = mapper.writeValueAsString(message);
|
||||
String serializedPendingMessage = mapper.writeValueAsString(message);
|
||||
WebsocketAddress websocketAddress = new WebsocketAddress(account.getNumber(), device.getId());
|
||||
|
||||
if (pubSubManager.publish(new WebsocketAddress(account.getId(), device.getId()),
|
||||
new PubSubMessage(PubSubMessage.TYPE_DELIVER,
|
||||
serializedPendingMessage)))
|
||||
if (pubSubManager.publish(websocketAddress, new PubSubMessage(PubSubMessage.TYPE_DELIVER,
|
||||
serializedPendingMessage)))
|
||||
{
|
||||
websocketMeter.mark();
|
||||
} else {
|
||||
memcacheSet(registrationId, account.getNumber());
|
||||
storedMessages.insert(account.getId(), device.getId(), message);
|
||||
storedMessages.insert(websocketAddress, message);
|
||||
|
||||
if (!message.isReceipt()) {
|
||||
sendPush(registrationId, serializedPendingMessage);
|
||||
|
||||
@@ -57,14 +57,14 @@ public class WebsocketSender {
|
||||
public void sendMessage(Account account, Device device, PendingMessage pendingMessage) {
|
||||
try {
|
||||
String serialized = mapper.writeValueAsString(pendingMessage);
|
||||
WebsocketAddress address = new WebsocketAddress(account.getId(), device.getId());
|
||||
WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId());
|
||||
PubSubMessage pubSubMessage = new PubSubMessage(PubSubMessage.TYPE_DELIVER, serialized);
|
||||
|
||||
if (pubSubManager.publish(address, pubSubMessage)) {
|
||||
onlineMeter.mark();
|
||||
} else {
|
||||
offlineMeter.mark();
|
||||
storedMessages.insert(account.getId(), device.getId(), pendingMessage);
|
||||
storedMessages.insert(address, pendingMessage);
|
||||
pubSubManager.publish(address, new PubSubMessage(PubSubMessage.TYPE_QUERY_DB, null));
|
||||
}
|
||||
} catch (JsonProcessingException e) {
|
||||
|
||||
@@ -22,7 +22,6 @@ import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.base.Optional;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
|
||||
@@ -30,9 +29,6 @@ public class Account {
|
||||
|
||||
public static final int MEMCACHE_VERION = 5;
|
||||
|
||||
@JsonIgnore
|
||||
private long id;
|
||||
|
||||
@JsonProperty
|
||||
private String number;
|
||||
|
||||
@@ -57,14 +53,6 @@ public class Account {
|
||||
this.devices = devices;
|
||||
}
|
||||
|
||||
public long getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
public void setId(long id) {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
public Optional<Device> getAuthenticatedDevice() {
|
||||
return authenticatedDevice;
|
||||
}
|
||||
|
||||
@@ -92,7 +92,7 @@ public abstract class Accounts {
|
||||
{
|
||||
try {
|
||||
Account account = mapper.readValue(resultSet.getString(DATA), Account.class);
|
||||
account.setId(resultSet.getLong(ID));
|
||||
// account.setId(resultSet.getLong(ID));
|
||||
|
||||
return account;
|
||||
} catch (IOException e) {
|
||||
|
||||
@@ -18,10 +18,12 @@ import redis.clients.jedis.JedisPubSub;
|
||||
|
||||
public class PubSubManager {
|
||||
|
||||
private final Logger logger = LoggerFactory.getLogger(PubSubManager.class);
|
||||
private final ObjectMapper mapper = SystemMapper.getMapper();
|
||||
private final SubscriptionListener baseListener = new SubscriptionListener();
|
||||
private final Map<WebsocketAddress, PubSubListener> listeners = new HashMap<>();
|
||||
private static final String KEEPALIVE_CHANNEL = "KEEPALIVE";
|
||||
|
||||
private final Logger logger = LoggerFactory.getLogger(PubSubManager.class);
|
||||
private final ObjectMapper mapper = SystemMapper.getMapper();
|
||||
private final SubscriptionListener baseListener = new SubscriptionListener();
|
||||
private final Map<String, PubSubListener> listeners = new HashMap<>();
|
||||
|
||||
private final JedisPool jedisPool;
|
||||
private boolean subscribed = false;
|
||||
@@ -33,25 +35,29 @@ public class PubSubManager {
|
||||
}
|
||||
|
||||
public synchronized void subscribe(WebsocketAddress address, PubSubListener listener) {
|
||||
listeners.put(address, listener);
|
||||
baseListener.subscribe(address.toString());
|
||||
listeners.put(address.serialize(), listener);
|
||||
baseListener.subscribe(address.serialize());
|
||||
}
|
||||
|
||||
public synchronized void unsubscribe(WebsocketAddress address, PubSubListener listener) {
|
||||
if (listeners.get(address) == listener) {
|
||||
listeners.remove(address);
|
||||
baseListener.unsubscribe(address.toString());
|
||||
if (listeners.get(address.serialize()) == listener) {
|
||||
listeners.remove(address.serialize());
|
||||
baseListener.unsubscribe(address.serialize());
|
||||
}
|
||||
}
|
||||
|
||||
public synchronized boolean publish(WebsocketAddress address, PubSubMessage message) {
|
||||
return publish(address.serialize(), message);
|
||||
}
|
||||
|
||||
private synchronized boolean publish(String channel, PubSubMessage message) {
|
||||
try {
|
||||
String serialized = mapper.writeValueAsString(message);
|
||||
Jedis jedis = null;
|
||||
|
||||
try {
|
||||
jedis = jedisPool.getResource();
|
||||
return jedis.publish(address.toString(), serialized) != 0;
|
||||
return jedis.publish(channel, serialized) != 0;
|
||||
} finally {
|
||||
if (jedis != null)
|
||||
jedisPool.returnResource(jedis);
|
||||
@@ -79,7 +85,7 @@ public class PubSubManager {
|
||||
Jedis jedis = null;
|
||||
try {
|
||||
jedis = jedisPool.getResource();
|
||||
jedis.subscribe(baseListener, new WebsocketAddress(0, 0).toString());
|
||||
jedis.subscribe(baseListener, KEEPALIVE_CHANNEL);
|
||||
logger.warn("**** Unsubscribed from holding channel!!! ******");
|
||||
} finally {
|
||||
if (jedis != null)
|
||||
@@ -95,7 +101,7 @@ public class PubSubManager {
|
||||
for (;;) {
|
||||
try {
|
||||
Thread.sleep(20000);
|
||||
publish(new WebsocketAddress(0, 0), new PubSubMessage(0, "foo"));
|
||||
publish(KEEPALIVE_CHANNEL, new PubSubMessage(0, "foo"));
|
||||
} catch (InterruptedException e) {
|
||||
throw new AssertionError(e);
|
||||
}
|
||||
@@ -109,18 +115,15 @@ public class PubSubManager {
|
||||
@Override
|
||||
public void onMessage(String channel, String message) {
|
||||
try {
|
||||
WebsocketAddress address = new WebsocketAddress(channel);
|
||||
PubSubListener listener;
|
||||
|
||||
synchronized (PubSubManager.this) {
|
||||
listener = listeners.get(address);
|
||||
listener = listeners.get(channel);
|
||||
}
|
||||
|
||||
if (listener != null) {
|
||||
listener.onPubSubMessage(mapper.readValue(message, PubSubMessage.class));
|
||||
}
|
||||
} catch (InvalidWebsocketAddressException e) {
|
||||
logger.warn("Address", e);
|
||||
} catch (IOException e) {
|
||||
logger.warn("IOE", e);
|
||||
}
|
||||
@@ -133,17 +136,11 @@ public class PubSubManager {
|
||||
|
||||
@Override
|
||||
public void onSubscribe(String channel, int count) {
|
||||
try {
|
||||
WebsocketAddress address = new WebsocketAddress(channel);
|
||||
|
||||
if (address.getAccountId() == 0 && address.getDeviceId() == 0) {
|
||||
synchronized (PubSubManager.this) {
|
||||
subscribed = true;
|
||||
PubSubManager.this.notifyAll();
|
||||
}
|
||||
if (KEEPALIVE_CHANNEL.equals(channel)) {
|
||||
synchronized (PubSubManager.this) {
|
||||
subscribed = true;
|
||||
PubSubManager.this.notifyAll();
|
||||
}
|
||||
} catch (InvalidWebsocketAddressException e) {
|
||||
logger.warn("Weird address", e);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.entities.PendingMessage;
|
||||
import org.whispersystems.textsecuregcm.util.Constants;
|
||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.LinkedList;
|
||||
@@ -53,18 +54,30 @@ public class StoredMessages {
|
||||
this.jedisPool = jedisPool;
|
||||
}
|
||||
|
||||
public void insert(long accountId, long deviceId, PendingMessage message) {
|
||||
public void clear(WebsocketAddress address) {
|
||||
Jedis jedis = null;
|
||||
|
||||
try {
|
||||
jedis = jedisPool.getResource();
|
||||
jedis.del(getKey(address));
|
||||
} finally {
|
||||
if (jedis != null)
|
||||
jedisPool.returnResource(jedis);
|
||||
}
|
||||
}
|
||||
|
||||
public void insert(WebsocketAddress address, PendingMessage message) {
|
||||
Jedis jedis = null;
|
||||
|
||||
try {
|
||||
jedis = jedisPool.getResource();
|
||||
|
||||
String serializedMessage = mapper.writeValueAsString(message);
|
||||
long queueSize = jedis.lpush(getKey(accountId, deviceId), serializedMessage);
|
||||
long queueSize = jedis.lpush(getKey(address), serializedMessage);
|
||||
queueSizeHistogram.update(queueSize);
|
||||
|
||||
if (queueSize > 1000) {
|
||||
jedis.ltrim(getKey(accountId, deviceId), 0, 999);
|
||||
jedis.ltrim(getKey(address), 0, 999);
|
||||
}
|
||||
|
||||
} catch (JsonProcessingException e) {
|
||||
@@ -75,7 +88,7 @@ public class StoredMessages {
|
||||
}
|
||||
}
|
||||
|
||||
public List<PendingMessage> getMessagesForDevice(long accountId, long deviceId) {
|
||||
public List<PendingMessage> getMessagesForDevice(WebsocketAddress address) {
|
||||
List<PendingMessage> messages = new LinkedList<>();
|
||||
Jedis jedis = null;
|
||||
|
||||
@@ -83,7 +96,7 @@ public class StoredMessages {
|
||||
jedis = jedisPool.getResource();
|
||||
String message;
|
||||
|
||||
while ((message = jedis.rpop(getKey(accountId, deviceId))) != null) {
|
||||
while ((message = jedis.rpop(getKey(address))) != null) {
|
||||
try {
|
||||
messages.add(mapper.readValue(message, PendingMessage.class));
|
||||
} catch (IOException e) {
|
||||
@@ -98,8 +111,8 @@ public class StoredMessages {
|
||||
}
|
||||
}
|
||||
|
||||
private String getKey(long accountId, long deviceId) {
|
||||
return QUEUE_PREFIX + ":" + accountId + ":" + deviceId;
|
||||
private String getKey(WebsocketAddress address) {
|
||||
return QUEUE_PREFIX + ":" + address.serialize();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -2,39 +2,20 @@ package org.whispersystems.textsecuregcm.websocket;
|
||||
|
||||
public class WebsocketAddress {
|
||||
|
||||
private final long accountId;
|
||||
private final long deviceId;
|
||||
private final String number;
|
||||
private final long deviceId;
|
||||
|
||||
public WebsocketAddress(String serialized) throws InvalidWebsocketAddressException {
|
||||
try {
|
||||
String[] parts = serialized.split(":");
|
||||
|
||||
if (parts == null || parts.length != 2) {
|
||||
throw new InvalidWebsocketAddressException(serialized);
|
||||
}
|
||||
|
||||
this.accountId = Long.parseLong(parts[0]);
|
||||
this.deviceId = Long.parseLong(parts[1]);
|
||||
} catch (NumberFormatException e) {
|
||||
throw new InvalidWebsocketAddressException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public WebsocketAddress(long accountId, long deviceId) {
|
||||
this.accountId = accountId;
|
||||
public WebsocketAddress(String number, long deviceId) {
|
||||
this.number = number;
|
||||
this.deviceId = deviceId;
|
||||
}
|
||||
|
||||
public long getAccountId() {
|
||||
return accountId;
|
||||
}
|
||||
|
||||
public long getDeviceId() {
|
||||
return deviceId;
|
||||
public String serialize() {
|
||||
return number + ":" + deviceId;
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return accountId + ":" + deviceId;
|
||||
return serialize();
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -45,13 +26,13 @@ public class WebsocketAddress {
|
||||
WebsocketAddress that = (WebsocketAddress)other;
|
||||
|
||||
return
|
||||
this.accountId == that.accountId &&
|
||||
this.number.equals(that.number) &&
|
||||
this.deviceId == that.deviceId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return (int)accountId ^ (int)deviceId;
|
||||
return number.hashCode() ^ (int)deviceId;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user