Fix for PubSub channel.

1) Create channels based on numbers rather than DB row ids.

2) Ensure that stored messages are cleared at reregistration
   time.
This commit is contained in:
Moxie Marlinspike
2014-07-26 20:41:25 -07:00
parent 4eb88a3e02
commit c9a1386a55
12 changed files with 77 additions and 91 deletions

View File

@@ -22,7 +22,6 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Optional;
import java.io.Serializable;
import java.util.LinkedList;
import java.util.List;
@@ -30,9 +29,6 @@ public class Account {
public static final int MEMCACHE_VERION = 5;
@JsonIgnore
private long id;
@JsonProperty
private String number;
@@ -57,14 +53,6 @@ public class Account {
this.devices = devices;
}
public long getId() {
return id;
}
public void setId(long id) {
this.id = id;
}
public Optional<Device> getAuthenticatedDevice() {
return authenticatedDevice;
}

View File

@@ -92,7 +92,7 @@ public abstract class Accounts {
{
try {
Account account = mapper.readValue(resultSet.getString(DATA), Account.class);
account.setId(resultSet.getLong(ID));
// account.setId(resultSet.getLong(ID));
return account;
} catch (IOException e) {

View File

@@ -18,10 +18,12 @@ import redis.clients.jedis.JedisPubSub;
public class PubSubManager {
private final Logger logger = LoggerFactory.getLogger(PubSubManager.class);
private final ObjectMapper mapper = SystemMapper.getMapper();
private final SubscriptionListener baseListener = new SubscriptionListener();
private final Map<WebsocketAddress, PubSubListener> listeners = new HashMap<>();
private static final String KEEPALIVE_CHANNEL = "KEEPALIVE";
private final Logger logger = LoggerFactory.getLogger(PubSubManager.class);
private final ObjectMapper mapper = SystemMapper.getMapper();
private final SubscriptionListener baseListener = new SubscriptionListener();
private final Map<String, PubSubListener> listeners = new HashMap<>();
private final JedisPool jedisPool;
private boolean subscribed = false;
@@ -33,25 +35,29 @@ public class PubSubManager {
}
public synchronized void subscribe(WebsocketAddress address, PubSubListener listener) {
listeners.put(address, listener);
baseListener.subscribe(address.toString());
listeners.put(address.serialize(), listener);
baseListener.subscribe(address.serialize());
}
public synchronized void unsubscribe(WebsocketAddress address, PubSubListener listener) {
if (listeners.get(address) == listener) {
listeners.remove(address);
baseListener.unsubscribe(address.toString());
if (listeners.get(address.serialize()) == listener) {
listeners.remove(address.serialize());
baseListener.unsubscribe(address.serialize());
}
}
public synchronized boolean publish(WebsocketAddress address, PubSubMessage message) {
return publish(address.serialize(), message);
}
private synchronized boolean publish(String channel, PubSubMessage message) {
try {
String serialized = mapper.writeValueAsString(message);
Jedis jedis = null;
try {
jedis = jedisPool.getResource();
return jedis.publish(address.toString(), serialized) != 0;
return jedis.publish(channel, serialized) != 0;
} finally {
if (jedis != null)
jedisPool.returnResource(jedis);
@@ -79,7 +85,7 @@ public class PubSubManager {
Jedis jedis = null;
try {
jedis = jedisPool.getResource();
jedis.subscribe(baseListener, new WebsocketAddress(0, 0).toString());
jedis.subscribe(baseListener, KEEPALIVE_CHANNEL);
logger.warn("**** Unsubscribed from holding channel!!! ******");
} finally {
if (jedis != null)
@@ -95,7 +101,7 @@ public class PubSubManager {
for (;;) {
try {
Thread.sleep(20000);
publish(new WebsocketAddress(0, 0), new PubSubMessage(0, "foo"));
publish(KEEPALIVE_CHANNEL, new PubSubMessage(0, "foo"));
} catch (InterruptedException e) {
throw new AssertionError(e);
}
@@ -109,18 +115,15 @@ public class PubSubManager {
@Override
public void onMessage(String channel, String message) {
try {
WebsocketAddress address = new WebsocketAddress(channel);
PubSubListener listener;
synchronized (PubSubManager.this) {
listener = listeners.get(address);
listener = listeners.get(channel);
}
if (listener != null) {
listener.onPubSubMessage(mapper.readValue(message, PubSubMessage.class));
}
} catch (InvalidWebsocketAddressException e) {
logger.warn("Address", e);
} catch (IOException e) {
logger.warn("IOE", e);
}
@@ -133,17 +136,11 @@ public class PubSubManager {
@Override
public void onSubscribe(String channel, int count) {
try {
WebsocketAddress address = new WebsocketAddress(channel);
if (address.getAccountId() == 0 && address.getDeviceId() == 0) {
synchronized (PubSubManager.this) {
subscribed = true;
PubSubManager.this.notifyAll();
}
if (KEEPALIVE_CHANNEL.equals(channel)) {
synchronized (PubSubManager.this) {
subscribed = true;
PubSubManager.this.notifyAll();
}
} catch (InvalidWebsocketAddressException e) {
logger.warn("Weird address", e);
}
}

View File

@@ -27,6 +27,7 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.PendingMessage;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import java.io.IOException;
import java.util.LinkedList;
@@ -53,18 +54,30 @@ public class StoredMessages {
this.jedisPool = jedisPool;
}
public void insert(long accountId, long deviceId, PendingMessage message) {
public void clear(WebsocketAddress address) {
Jedis jedis = null;
try {
jedis = jedisPool.getResource();
jedis.del(getKey(address));
} finally {
if (jedis != null)
jedisPool.returnResource(jedis);
}
}
public void insert(WebsocketAddress address, PendingMessage message) {
Jedis jedis = null;
try {
jedis = jedisPool.getResource();
String serializedMessage = mapper.writeValueAsString(message);
long queueSize = jedis.lpush(getKey(accountId, deviceId), serializedMessage);
long queueSize = jedis.lpush(getKey(address), serializedMessage);
queueSizeHistogram.update(queueSize);
if (queueSize > 1000) {
jedis.ltrim(getKey(accountId, deviceId), 0, 999);
jedis.ltrim(getKey(address), 0, 999);
}
} catch (JsonProcessingException e) {
@@ -75,7 +88,7 @@ public class StoredMessages {
}
}
public List<PendingMessage> getMessagesForDevice(long accountId, long deviceId) {
public List<PendingMessage> getMessagesForDevice(WebsocketAddress address) {
List<PendingMessage> messages = new LinkedList<>();
Jedis jedis = null;
@@ -83,7 +96,7 @@ public class StoredMessages {
jedis = jedisPool.getResource();
String message;
while ((message = jedis.rpop(getKey(accountId, deviceId))) != null) {
while ((message = jedis.rpop(getKey(address))) != null) {
try {
messages.add(mapper.readValue(message, PendingMessage.class));
} catch (IOException e) {
@@ -98,8 +111,8 @@ public class StoredMessages {
}
}
private String getKey(long accountId, long deviceId) {
return QUEUE_PREFIX + ":" + accountId + ":" + deviceId;
private String getKey(WebsocketAddress address) {
return QUEUE_PREFIX + ":" + address.serialize();
}
}