Break out into a multi-module project

This commit is contained in:
Moxie Marlinspike
2019-04-20 21:56:20 -07:00
parent b41dde777e
commit d0d375aeb7
318 changed files with 255 additions and 215 deletions

View File

@@ -0,0 +1,84 @@
package org.whispersystems.textsecuregcm.websocket;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
import com.google.protobuf.ByteString;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.redis.RedisOperation;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import org.whispersystems.websocket.setup.WebSocketConnectListener;
import java.security.SecureRandom;
import static com.codahale.metrics.MetricRegistry.name;
public class AuthenticatedConnectListener implements WebSocketConnectListener {
private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class);
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private static final Timer durationTimer = metricRegistry.timer(name(WebSocketConnection.class, "connected_duration" ));
private static final Timer unauthenticatedDurationTimer = metricRegistry.timer(name(WebSocketConnection.class, "unauthenticated_connection_duration"));
private final PushSender pushSender;
private final ReceiptSender receiptSender;
private final MessagesManager messagesManager;
private final PubSubManager pubSubManager;
private final ApnFallbackManager apnFallbackManager;
public AuthenticatedConnectListener(PushSender pushSender,
ReceiptSender receiptSender,
MessagesManager messagesManager,
PubSubManager pubSubManager,
ApnFallbackManager apnFallbackManager)
{
this.pushSender = pushSender;
this.receiptSender = receiptSender;
this.messagesManager = messagesManager;
this.pubSubManager = pubSubManager;
this.apnFallbackManager = apnFallbackManager;
}
@Override
public void onWebSocketConnect(WebSocketSessionContext context) {
if (context.getAuthenticated() != null) {
final Account account = context.getAuthenticated(Account.class);
final Device device = account.getAuthenticatedDevice().get();
final String connectionId = String.valueOf(new SecureRandom().nextLong());
final Timer.Context timer = durationTimer.time();
final WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId());
final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender,
messagesManager, account, device,
context.getClient(), connectionId);
final PubSubMessage connectMessage = PubSubMessage.newBuilder().setType(PubSubMessage.Type.CONNECTED)
.setContent(ByteString.copyFrom(connectionId.getBytes()))
.build();
RedisOperation.unchecked(() -> apnFallbackManager.cancel(account, device));
pubSubManager.publish(address, connectMessage);
pubSubManager.subscribe(address, connection);
context.addListener(new WebSocketSessionContext.WebSocketEventListener() {
@Override
public void onWebSocketClose(WebSocketSessionContext context, int statusCode, String reason) {
pubSubManager.unsubscribe(address, connection);
timer.stop();
}
});
} else {
final Timer.Context timer = unauthenticatedDurationTimer.time();
context.addListener((context1, statusCode, reason) -> timer.stop());
}
}
}

View File

@@ -0,0 +1,51 @@
package org.whispersystems.textsecuregcm.websocket;
import com.google.protobuf.InvalidProtocolBufferException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.dispatch.DispatchChannel;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage;
public class DeadLetterHandler implements DispatchChannel {
private final Logger logger = LoggerFactory.getLogger(DeadLetterHandler.class);
private final MessagesManager messagesManager;
public DeadLetterHandler(MessagesManager messagesManager) {
this.messagesManager = messagesManager;
}
@Override
public void onDispatchMessage(String channel, byte[] data) {
try {
logger.info("Handling dead letter to: " + channel);
WebsocketAddress address = new WebsocketAddress(channel);
PubSubMessage pubSubMessage = PubSubMessage.parseFrom(data);
switch (pubSubMessage.getType().getNumber()) {
case PubSubMessage.Type.DELIVER_VALUE:
Envelope message = Envelope.parseFrom(pubSubMessage.getContent());
messagesManager.insert(address.getNumber(), address.getDeviceId(), message);
break;
}
} catch (InvalidProtocolBufferException e) {
logger.warn("Bad pubsub message", e);
} catch (InvalidWebsocketAddressException e) {
logger.warn("Invalid websocket address", e);
}
}
@Override
public void onDispatchSubscribed(String channel) {
logger.warn("DeadLetterHandler subscription notice! " + channel);
}
@Override
public void onDispatchUnsubscribed(String channel) {
logger.warn("DeadLetterHandler unsubscribe notice! " + channel);
}
}

View File

@@ -0,0 +1,11 @@
package org.whispersystems.textsecuregcm.websocket;
public class InvalidWebsocketAddressException extends Exception {
public InvalidWebsocketAddressException(String serialized) {
super(serialized);
}
public InvalidWebsocketAddressException(Exception e) {
super(e);
}
}

View File

@@ -0,0 +1,33 @@
package org.whispersystems.textsecuregcm.websocket;
import org.whispersystems.textsecuregcm.util.Base64;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
public class ProvisioningAddress extends WebsocketAddress {
public ProvisioningAddress(String address, int id) throws InvalidWebsocketAddressException {
super(address, id);
}
public ProvisioningAddress(String serialized) throws InvalidWebsocketAddressException {
super(serialized);
}
public String getAddress() {
return getNumber();
}
public static ProvisioningAddress generate() {
try {
byte[] random = new byte[16];
new SecureRandom().nextBytes(random);
return new ProvisioningAddress(Base64.encodeBytesWithoutPadding(random)
.replace('+', '-').replace('/', '_'), 0);
} catch (InvalidWebsocketAddressException e) {
throw new AssertionError(e);
}
}
}

View File

@@ -0,0 +1,29 @@
package org.whispersystems.textsecuregcm.websocket;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import org.whispersystems.websocket.setup.WebSocketConnectListener;
public class ProvisioningConnectListener implements WebSocketConnectListener {
private final PubSubManager pubSubManager;
public ProvisioningConnectListener(PubSubManager pubSubManager) {
this.pubSubManager = pubSubManager;
}
@Override
public void onWebSocketConnect(WebSocketSessionContext context) {
final ProvisioningConnection connection = new ProvisioningConnection(context.getClient());
final ProvisioningAddress provisioningAddress = ProvisioningAddress.generate();
pubSubManager.subscribe(provisioningAddress, connection);
context.addListener(new WebSocketSessionContext.WebSocketEventListener() {
@Override
public void onWebSocketClose(WebSocketSessionContext context, int statusCode, String reason) {
pubSubManager.unsubscribe(provisioningAddress, connection);
}
});
}
}

View File

@@ -0,0 +1,72 @@
package org.whispersystems.textsecuregcm.websocket;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.protobuf.InvalidProtocolBufferException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.dispatch.DispatchChannel;
import org.whispersystems.textsecuregcm.entities.MessageProtos.ProvisioningUuid;
import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import java.util.Optional;
public class ProvisioningConnection implements DispatchChannel {
private final Logger logger = LoggerFactory.getLogger(ProvisioningConnection.class);
private final WebSocketClient client;
public ProvisioningConnection(WebSocketClient client) {
this.client = client;
}
@Override
public void onDispatchMessage(String channel, byte[] message) {
try {
PubSubMessage outgoingMessage = PubSubMessage.parseFrom(message);
if (outgoingMessage.getType() == PubSubMessage.Type.DELIVER) {
Optional<byte[]> body = Optional.of(outgoingMessage.getContent().toByteArray());
ListenableFuture<WebSocketResponseMessage> response = client.sendRequest("PUT", "/v1/message", null, body);
Futures.addCallback(response, new FutureCallback<WebSocketResponseMessage>() {
@Override
public void onSuccess(WebSocketResponseMessage webSocketResponseMessage) {
client.close(1001, "All you get.");
}
@Override
public void onFailure(Throwable throwable) {
client.close(1001, "That's all!");
}
});
}
} catch (InvalidProtocolBufferException e) {
logger.warn("Protobuf Error: ", e);
}
}
@Override
public void onDispatchSubscribed(String channel) {
try {
ProvisioningAddress address = new ProvisioningAddress(channel);
this.client.sendRequest("PUT", "/v1/address", null, Optional.of(ProvisioningUuid.newBuilder()
.setUuid(address.getAddress())
.build()
.toByteArray()));
} catch (InvalidWebsocketAddressException e) {
logger.warn("Badly formatted address", e);
this.client.close(1001, "Server Error");
}
}
@Override
public void onDispatchUnsubscribed(String channel) {
this.client.close(1001, "Closed");
}
}

View File

@@ -0,0 +1,47 @@
package org.whispersystems.textsecuregcm.websocket;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.websocket.auth.AuthenticationException;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import io.dropwizard.auth.basic.BasicCredentials;
public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<Account> {
private final AccountAuthenticator accountAuthenticator;
public WebSocketAccountAuthenticator(AccountAuthenticator accountAuthenticator) {
this.accountAuthenticator = accountAuthenticator;
}
@Override
public AuthenticationResult<Account> authenticate(UpgradeRequest request) throws AuthenticationException {
try {
Map<String, List<String>> parameters = request.getParameterMap();
List<String> usernames = parameters.get("login");
List<String> passwords = parameters.get("password");
if (usernames == null || usernames.size() == 0 ||
passwords == null || passwords.size() == 0)
{
return new AuthenticationResult<>(Optional.empty(), false);
}
BasicCredentials credentials = new BasicCredentials(usernames.get(0).replace(" ", "+"),
passwords.get(0).replace(" ", "+"));
return new AuthenticationResult<>(accountAuthenticator.authenticate(credentials), true);
} catch (io.dropwizard.auth.AuthenticationException e) {
throw new AuthenticationException(e);
}
}
}

View File

@@ -0,0 +1,227 @@
package org.whispersystems.textsecuregcm.websocket;
import com.codahale.metrics.Histogram;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.dispatch.DispatchChannel;
import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.controllers.NoSuchUserException;
import org.whispersystems.textsecuregcm.entities.CryptoEncodingException;
import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.ws.rs.WebApplicationException;
import java.util.Collections;
import java.util.Iterator;
import java.util.Optional;
import static com.codahale.metrics.MetricRegistry.name;
import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import static org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class WebSocketConnection implements DispatchChannel {
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
public static final Histogram messageTime = metricRegistry.histogram(name(MessageController.class, "message_delivery_duration"));
private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class);
private final ReceiptSender receiptSender;
private final PushSender pushSender;
private final MessagesManager messagesManager;
private final Account account;
private final Device device;
private final WebSocketClient client;
private final String connectionId;
public WebSocketConnection(PushSender pushSender,
ReceiptSender receiptSender,
MessagesManager messagesManager,
Account account,
Device device,
WebSocketClient client,
String connectionId)
{
this.pushSender = pushSender;
this.receiptSender = receiptSender;
this.messagesManager = messagesManager;
this.account = account;
this.device = device;
this.client = client;
this.connectionId = connectionId;
}
@Override
public void onDispatchMessage(String channel, byte[] message) {
try {
PubSubMessage pubSubMessage = PubSubMessage.parseFrom(message);
switch (pubSubMessage.getType().getNumber()) {
case PubSubMessage.Type.QUERY_DB_VALUE:
processStoredMessages();
break;
case PubSubMessage.Type.DELIVER_VALUE:
sendMessage(Envelope.parseFrom(pubSubMessage.getContent()), Optional.empty(), false);
break;
case PubSubMessage.Type.CONNECTED_VALUE:
if (pubSubMessage.hasContent() && !new String(pubSubMessage.getContent().toByteArray()).equals(connectionId)) {
client.hardDisconnectQuietly();
}
break;
default:
logger.warn("Unknown pubsub message: " + pubSubMessage.getType().getNumber());
}
} catch (InvalidProtocolBufferException e) {
logger.warn("Protobuf parse error", e);
}
}
@Override
public void onDispatchUnsubscribed(String channel) {
client.close(1000, "OK");
}
public void onDispatchSubscribed(String channel) {
processStoredMessages();
}
private void sendMessage(final Envelope message,
final Optional<StoredMessageInfo> storedMessageInfo,
final boolean requery)
{
try {
String header;
Optional<byte[]> body;
if (Util.isEmpty(device.getSignalingKey())) {
header = "X-Signal-Key: false";
body = Optional.ofNullable(message.toByteArray());
} else {
header = "X-Signal-Key: true";
body = Optional.ofNullable(new EncryptedOutgoingMessage(message, device.getSignalingKey()).toByteArray());
}
ListenableFuture<WebSocketResponseMessage> response = client.sendRequest("PUT", "/api/v1/message", Collections.singletonList(header), body);
Futures.addCallback(response, new FutureCallback<WebSocketResponseMessage>() {
@Override
public void onSuccess(@Nullable WebSocketResponseMessage response) {
boolean isReceipt = message.getType() == Envelope.Type.RECEIPT;
if (isSuccessResponse(response) && !isReceipt) {
messageTime.update(System.currentTimeMillis() - message.getTimestamp());
}
if (isSuccessResponse(response)) {
if (storedMessageInfo.isPresent()) messagesManager.delete(account.getNumber(), device.getId(), storedMessageInfo.get().id, storedMessageInfo.get().cached);
if (!isReceipt) sendDeliveryReceiptFor(message);
if (requery) processStoredMessages();
} else if (!isSuccessResponse(response) && !storedMessageInfo.isPresent()) {
requeueMessage(message);
}
}
@Override
public void onFailure(@Nonnull Throwable throwable) {
if (!storedMessageInfo.isPresent()) requeueMessage(message);
}
private boolean isSuccessResponse(WebSocketResponseMessage response) {
return response != null && response.getStatus() >= 200 && response.getStatus() < 300;
}
});
} catch (CryptoEncodingException e) {
logger.warn("Bad signaling key", e);
}
}
private void requeueMessage(Envelope message) {
pushSender.getWebSocketSender().queueMessage(account, device, message);
try {
pushSender.sendQueuedNotification(account, device);
} catch (NotPushRegisteredException e) {
logger.warn("requeueMessage", e);
}
}
private void sendDeliveryReceiptFor(Envelope message) {
if (!message.hasSource()) return;
try {
receiptSender.sendReceipt(account, message.getSource(), message.getTimestamp());
} catch (NoSuchUserException | NotPushRegisteredException e) {
logger.info("No longer registered " + e.getMessage());
} catch (WebApplicationException e) {
logger.warn("Bad federated response for receipt: " + e.getResponse().getStatus());
}
}
private void processStoredMessages() {
OutgoingMessageEntityList messages = messagesManager.getMessagesForDevice(account.getNumber(), device.getId());
Iterator<OutgoingMessageEntity> iterator = messages.getMessages().iterator();
while (iterator.hasNext()) {
OutgoingMessageEntity message = iterator.next();
Envelope.Builder builder = Envelope.newBuilder()
.setType(Envelope.Type.valueOf(message.getType()))
.setTimestamp(message.getTimestamp())
.setServerTimestamp(message.getServerTimestamp());
if (!Util.isEmpty(message.getSource())) {
builder.setSource(message.getSource())
.setSourceDevice(message.getSourceDevice());
}
if (message.getMessage() != null) {
builder.setLegacyMessage(ByteString.copyFrom(message.getMessage()));
}
if (message.getContent() != null) {
builder.setContent(ByteString.copyFrom(message.getContent()));
}
if (message.getRelay() != null && !message.getRelay().isEmpty()) {
builder.setRelay(message.getRelay());
}
sendMessage(builder.build(), Optional.of(new StoredMessageInfo(message.getId(), message.isCached())), !iterator.hasNext() && messages.hasMore());
}
if (!messages.hasMore()) {
client.sendRequest("PUT", "/api/v1/queue/empty", null, Optional.empty());
}
}
private static class StoredMessageInfo {
private final long id;
private final boolean cached;
private StoredMessageInfo(long id, boolean cached) {
this.id = id;
this.cached = cached;
}
}
}

View File

@@ -0,0 +1,63 @@
package org.whispersystems.textsecuregcm.websocket;
import org.whispersystems.textsecuregcm.storage.PubSubAddress;
public class WebsocketAddress implements PubSubAddress {
private final String number;
private final long deviceId;
public WebsocketAddress(String number, long deviceId) {
this.number = number;
this.deviceId = deviceId;
}
public WebsocketAddress(String serialized) throws InvalidWebsocketAddressException {
try {
String[] parts = serialized.split(":", 2);
if (parts.length != 2) {
throw new InvalidWebsocketAddressException("Bad address: " + serialized);
}
this.number = parts[0];
this.deviceId = Long.parseLong(parts[1]);
} catch (NumberFormatException e) {
throw new InvalidWebsocketAddressException(e);
}
}
public String getNumber() {
return number;
}
public long getDeviceId() {
return deviceId;
}
public String serialize() {
return number + ":" + deviceId;
}
public String toString() {
return serialize();
}
@Override
public boolean equals(Object other) {
if (other == null) return false;
if (!(other instanceof WebsocketAddress)) return false;
WebsocketAddress that = (WebsocketAddress)other;
return
this.number.equals(that.number) &&
this.deviceId == that.deviceId;
}
@Override
public int hashCode() {
return number.hashCode() ^ (int)deviceId;
}
}