Initial multi device support refactoring.

1) Store account data as a json type, which includes all
   devices in a single object.

2) Simplify message delivery logic.

3) Make federated calls a pass through to standard controllers.

4) Simplify key retrieval logic.
This commit is contained in:
Moxie Marlinspike
2014-01-18 23:45:07 -08:00
parent 6f9226dcf9
commit 74f71fd8a6
47 changed files with 961 additions and 1211 deletions

View File

@@ -16,383 +16,226 @@
*/
package org.whispersystems.textsecuregcm.controllers;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Optional;
import com.google.protobuf.ByteString;
import com.yammer.dropwizard.auth.AuthenticationException;
import com.yammer.dropwizard.auth.basic.BasicCredentials;
import com.yammer.metrics.Metrics;
import com.yammer.metrics.core.Meter;
import com.yammer.metrics.core.Timer;
import com.yammer.metrics.core.TimerContext;
import com.yammer.dropwizard.auth.Auth;
import com.yammer.metrics.annotation.Timed;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.DeviceAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthorizationHeader;
import org.whispersystems.textsecuregcm.auth.InvalidAuthorizationHeaderException;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal;
import org.whispersystems.textsecuregcm.entities.MessageResponse;
import org.whispersystems.textsecuregcm.entities.RelayMessage;
import org.whispersystems.textsecuregcm.entities.MissingDevices;
import org.whispersystems.textsecuregcm.federation.FederatedClient;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.federation.NoSuchPeerException;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.TransientPushFailureException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Base64;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util;
import javax.annotation.Nullable;
import javax.servlet.AsyncContext;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import javax.validation.Valid;
import javax.ws.rs.Consumes;
import javax.ws.rs.POST;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
public class MessageController extends HttpServlet {
@Path("/v1/messages")
public class MessageController {
public static final String PATH = "/v1/messages/";
private final Meter successMeter = Metrics.newMeter(MessageController.class, "deliver_message", "success", TimeUnit.MINUTES);
private final Meter failureMeter = Metrics.newMeter(MessageController.class, "deliver_message", "failure", TimeUnit.MINUTES);
private final Timer timer = Metrics.newTimer(MessageController.class, "deliver_message_time", TimeUnit.MILLISECONDS, TimeUnit.MINUTES);
private final Logger logger = LoggerFactory.getLogger(MessageController.class);
private final Logger logger = LoggerFactory.getLogger(MessageController.class);
private final RateLimiters rateLimiters;
private final DeviceAuthenticator deviceAuthenticator;
private final PushSender pushSender;
private final FederatedClientManager federatedClientManager;
private final ObjectMapper objectMapper;
private final ExecutorService executor;
private final AccountsManager accountsManager;
public MessageController(RateLimiters rateLimiters,
DeviceAuthenticator deviceAuthenticator,
PushSender pushSender,
AccountsManager accountsManager,
FederatedClientManager federatedClientManager)
{
this.rateLimiters = rateLimiters;
this.deviceAuthenticator = deviceAuthenticator;
this.pushSender = pushSender;
this.accountsManager = accountsManager;
this.federatedClientManager = federatedClientManager;
this.objectMapper = new ObjectMapper();
this.executor = Executors.newFixedThreadPool(10);
}
class LocalOrRemoteDevice {
Device device;
String relay, number; long deviceId;
LocalOrRemoteDevice(Device device) {
this.device = device; this.number = device.getNumber(); this.deviceId = device.getDeviceId();
}
LocalOrRemoteDevice(String relay, String number, long deviceId) {
this.relay = relay; this.number = number; this.deviceId = deviceId;
}
}
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) {
TimerContext timerContext = timer.time();
@Timed
@Path("/{destination}")
@PUT
@Consumes(MediaType.APPLICATION_JSON)
public void sendMessage(@Auth Account source,
@PathParam("destination") String destinationName,
@Valid IncomingMessageList messages)
throws IOException, RateLimitExceededException
{
rateLimiters.getMessagesLimiter().validate(source.getNumber());
try {
Device sender = authenticate(req);
rateLimiters.getMessagesLimiter().validate(sender.getNumber());
handleAsyncDelivery(timerContext, req.startAsync(), sender, parseIncomingMessages(req));
} catch (AuthenticationException e) {
failureMeter.mark();
timerContext.stop();
resp.setStatus(401);
} catch (ValidationException e) {
failureMeter.mark();
timerContext.stop();
resp.setStatus(415);
} catch (IOException e) {
logger.warn("IOE", e);
failureMeter.mark();
timerContext.stop();
resp.setStatus(501);
} catch (RateLimitExceededException e) {
timerContext.stop();
failureMeter.mark();
resp.setStatus(413);
if (messages.getRelay() != null) sendLocalMessage(source, destinationName, messages);
else sendRelayMessage(source, destinationName, messages);
} catch (NoSuchUserException e) {
throw new WebApplicationException(Response.status(404).build());
} catch (MissingDevicesException e) {
throw new WebApplicationException(Response.status(409)
.entity(new MissingDevices(e.getMissingDevices()))
.build());
}
}
private void handleAsyncDelivery(final TimerContext timerContext,
final AsyncContext context,
final Device sender,
final IncomingMessageList messages)
@Timed
@Path("/")
@POST
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public MessageResponse sendMessageLegacy(@Auth Account source, @Valid IncomingMessageList messages)
throws IOException, RateLimitExceededException
{
executor.submit(new Runnable() {
@Override
public void run() {
List<String> success = new LinkedList<>();
List<String> failure = new LinkedList<>();
HttpServletResponse response = (HttpServletResponse) context.getResponse();
try {
List<IncomingMessage> incomingMessages = messages.getMessages();
validateLegacyDestinations(incomingMessages);
try {
List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> outgoingMessages;
try {
outgoingMessages = getOutgoingMessageSignals(sender.getNumber(), messages.getMessages());
} catch (MissingDevicesException e) {
byte[] responseData = serializeResponse(new MessageResponse(e.missingNumbers));
response.setContentLength(responseData.length);
response.getOutputStream().write(responseData);
context.complete();
failureMeter.mark();
timerContext.stop();
return;
}
messages.setRelay(incomingMessages.get(0).getRelay());
sendMessage(source, incomingMessages.get(0).getDestination(), messages);
Map<String, Set<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>>> relayMessages = new HashMap<>();
for (Pair<LocalOrRemoteDevice, OutgoingMessageSignal> messagePair : outgoingMessages) {
String relay = messagePair.first().relay;
if (Util.isEmpty(relay)) {
String encodedId = messagePair.first().device.getBackwardsCompatibleNumberEncoding();
try {
pushSender.sendMessage(messagePair.first().device, messagePair.second());
success.add(encodedId);
} catch (NoSuchUserException e) {
logger.debug("No such user", e);
failure.add(encodedId);
}
} else {
Set<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> messageSet = relayMessages.get(relay);
if (messageSet == null) {
messageSet = new HashSet<>();
relayMessages.put(relay, messageSet);
}
messageSet.add(messagePair);
}
}
for (Map.Entry<String, Set<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>>> messagesForRelay : relayMessages.entrySet()) {
try {
FederatedClient client = federatedClientManager.getClient(messagesForRelay.getKey());
List<RelayMessage> messages = new LinkedList<>();
for (Pair<LocalOrRemoteDevice, OutgoingMessageSignal> message : messagesForRelay.getValue()) {
messages.add(new RelayMessage(message.first().number,
message.first().deviceId,
message.second().toByteArray()));
}
MessageResponse relayResponse = client.sendMessages(messages);
for (String string : relayResponse.getSuccess())
success.add(string);
for (String string : relayResponse.getFailure())
failure.add(string);
} catch (NoSuchPeerException e) {
logger.info("No such peer", e);
for (Pair<LocalOrRemoteDevice, OutgoingMessageSignal> messagePair : messagesForRelay.getValue())
failure.add(messagePair.first().number);
}
}
byte[] responseData = serializeResponse(new MessageResponse(success, failure));
response.setContentLength(responseData.length);
response.getOutputStream().write(responseData);
context.complete();
successMeter.mark();
} catch (IOException e) {
logger.warn("Async Handler", e);
failureMeter.mark();
response.setStatus(501);
context.complete();
} catch (Exception e) {
logger.error("Unknown error sending message", e);
failureMeter.mark();
response.setStatus(500);
context.complete();
}
timerContext.stop();
}
});
return new MessageResponse(new LinkedList<String>(), new LinkedList<String>());
} catch (ValidationException e) {
throw new WebApplicationException(Response.status(422).build());
}
}
@Nullable
private List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> getOutgoingMessageSignals(String sourceNumber,
List<IncomingMessage> incomingMessages)
private void sendLocalMessage(Account source,
String destinationName,
IncomingMessageList messages)
throws NoSuchUserException, MissingDevicesException, IOException
{
Account destination = getDestinationAccount(destinationName);
validateCompleteDeviceList(destination, messages.getMessages());
for (IncomingMessage incomingMessage : messages.getMessages()) {
Optional<Device> destinationDevice = destination.getDevice(incomingMessage.getDestinationDeviceId());
if (destinationDevice.isPresent()) {
sendLocalMessage(source, destination, destinationDevice.get(), incomingMessage);
}
}
}
private void sendLocalMessage(Account source,
Account destinationAccount,
Device destinationDevice,
IncomingMessage incomingMessage)
throws NoSuchUserException, IOException
{
try {
Optional<byte[]> messageBody = getMessageBody(incomingMessage);
OutgoingMessageSignal.Builder messageBuilder = OutgoingMessageSignal.newBuilder();
messageBuilder.setType(incomingMessage.getType())
.setSource(source.getNumber())
.setTimestamp(System.currentTimeMillis());
if (messageBody.isPresent()) {
messageBuilder.setMessage(ByteString.copyFrom(messageBody.get()));
}
if (source.getRelay().isPresent()) {
messageBuilder.setRelay(source.getRelay().get());
}
pushSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build());
} catch (NotPushRegisteredException e) {
if (destinationDevice.isMaster()) throw new NoSuchUserException(e);
else logger.debug("Not registered", e);
} catch (TransientPushFailureException e) {
if (destinationDevice.isMaster()) throw new IOException(e);
else logger.debug("Transient failure", e);
}
}
private void sendRelayMessage(Account source,
String destinationName,
IncomingMessageList messages)
throws IOException, NoSuchUserException
{
try {
FederatedClient client = federatedClientManager.getClient(messages.getRelay());
client.sendMessages(source.getNumber(), destinationName, messages);
} catch (NoSuchPeerException e) {
throw new NoSuchUserException(e);
}
}
private Account getDestinationAccount(String destination)
throws NoSuchUserException
{
Optional<Account> account = accountsManager.get(destination);
if (!account.isPresent() || !account.get().isActive()) {
throw new NoSuchUserException(destination);
}
return account.get();
}
private void validateCompleteDeviceList(Account account, List<IncomingMessage> messages)
throws MissingDevicesException
{
List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> outgoingMessages = new LinkedList<>();
Set<Long> destinationDeviceIds = new HashSet<>();
List<Long> missingDeviceIds = new LinkedList<>();
List<Account> localAccounts = accountsManager.getAccountsForDevices(getLocalDestinations(incomingMessages));
Set<String> destinationNumbers = new HashSet<>();
for (IncomingMessage incoming : incomingMessages)
destinationNumbers.add(incoming.getDestination());
for (IncomingMessage incoming : incomingMessages) {
OutgoingMessageSignal.Builder outgoingMessage = OutgoingMessageSignal.newBuilder();
outgoingMessage.setType(incoming.getType());
outgoingMessage.setSource(sourceNumber);
byte[] messageBody = getMessageBody(incoming);
if (messageBody != null) {
outgoingMessage.setMessage(ByteString.copyFrom(messageBody));
}
outgoingMessage.setTimestamp(System.currentTimeMillis());
for (String destination : destinationNumbers) {
if (!destination.equals(incoming.getDestination()))
outgoingMessage.addDestinations(destination);
}
LocalOrRemoteDevice device = null;
if (!Util.isEmpty(incoming.getRelay()))
device = new LocalOrRemoteDevice(incoming.getRelay(), incoming.getDestination(), incoming.getDestinationDeviceId());
else {
Account destination = null;
for (Account account : localAccounts) {
if (account.getNumber().equals(incoming.getDestination())) {
destination = account;
break;
}
}
if (destination != null)
device = new LocalOrRemoteDevice(destination.getDevice(incoming.getDestinationDeviceId()));
}
if (device != null)
outgoingMessages.add(new Pair<>(device, outgoingMessage.build()));
for (IncomingMessage message : messages) {
destinationDeviceIds.add(message.getDestinationDeviceId());
}
return outgoingMessages;
}
// We use a map from number -> deviceIds here (instead of passing the list of messages to accountsManager) so that
// we can share as much code as possible with FederationController (which has RelayMessages, not IncomingMessages)
private Map<String, Set<Long>> getLocalDestinations(List<IncomingMessage> incomingMessages) {
Map<String, Set<Long>> localDestinations = new HashMap<>();
for (IncomingMessage incoming : incomingMessages) {
if (!Util.isEmpty(incoming.getRelay()))
continue;
Set<Long> deviceIds = localDestinations.get(incoming.getDestination());
if (deviceIds == null) {
deviceIds = new HashSet<>();
localDestinations.put(incoming.getDestination(), deviceIds);
for (Device device : account.getDevices()) {
if (!destinationDeviceIds.contains(device.getId())) {
missingDeviceIds.add(device.getId());
}
deviceIds.add(incoming.getDestinationDeviceId());
}
return localDestinations;
}
private byte[] getMessageBody(IncomingMessage message) {
try {
return Base64.decode(message.getBody());
} catch (IOException ioe) {
ioe.printStackTrace();
return null;
if (!missingDeviceIds.isEmpty()) {
throw new MissingDevicesException(missingDeviceIds);
}
}
private byte[] serializeResponse(MessageResponse response) throws IOException {
try {
return objectMapper.writeValueAsBytes(response);
} catch (JsonProcessingException e) {
throw new IOException(e);
}
}
private IncomingMessageList parseIncomingMessages(HttpServletRequest request)
throws IOException, ValidationException
private void validateLegacyDestinations(List<IncomingMessage> messages)
throws ValidationException
{
BufferedReader reader = request.getReader();
StringBuilder content = new StringBuilder();
String line;
String destination = null;
while ((line = reader.readLine()) != null) {
content.append(line);
for (IncomingMessage message : messages) {
if (destination != null && !destination.equals(message.getDestination())) {
throw new ValidationException("Multiple account destinations!");
}
destination = message.getDestination();
}
IncomingMessageList messages = objectMapper.readValue(content.toString(),
IncomingMessageList.class);
if (messages.getMessages() == null) {
throw new ValidationException();
}
for (IncomingMessage message : messages.getMessages()) {
if (message.getBody() == null) throw new ValidationException();
if (message.getDestination() == null) throw new ValidationException();
}
return messages;
}
private Device authenticate(HttpServletRequest request) throws AuthenticationException {
private Optional<byte[]> getMessageBody(IncomingMessage message) {
try {
AuthorizationHeader authorizationHeader = AuthorizationHeader.fromFullHeader(request.getHeader("Authorization"));
BasicCredentials credentials = new BasicCredentials(authorizationHeader.getNumber() + "." + authorizationHeader.getDeviceId(),
authorizationHeader.getPassword() );
Optional<Device> account = deviceAuthenticator.authenticate(credentials);
if (account.isPresent()) return account.get();
else throw new AuthenticationException("Bad credentials");
} catch (InvalidAuthorizationHeaderException e) {
throw new AuthenticationException(e);
return Optional.of(Base64.decode(message.getBody()));
} catch (IOException ioe) {
logger.debug("Bad B64", ioe);
return Optional.absent();
}
}
// @Timed
// @POST
// @Consumes(MediaType.APPLICATION_JSON)
// @Produces(MediaType.APPLICATION_JSON)
// public MessageResponse sendMessage(@Auth Device sender, IncomingMessageList messages)
// throws IOException
// {
// List<String> success = new LinkedList<>();
// List<String> failure = new LinkedList<>();
// List<IncomingMessage> incomingMessages = messages.getMessages();
// List<OutgoingMessageSignal> outgoingMessages = getOutgoingMessageSignals(sender.getNumber(), incomingMessages);
//
// IterablePair<IncomingMessage, OutgoingMessageSignal> listPair = new IterablePair<>(incomingMessages, outgoingMessages);
//
// for (Pair<IncomingMessage, OutgoingMessageSignal> messagePair : listPair) {
// String destination = messagePair.first().getDestination();
// String relay = messagePair.first().getRelay();
//
// try {
// if (Util.isEmpty(relay)) sendLocalMessage(destination, messagePair.second());
// else sendRelayMessage(relay, destination, messagePair.second());
// success.add(destination);
// } catch (NoSuchUserException e) {
// logger.debug("No such user", e);
// failure.add(destination);
// }
// }
//
// return new MessageResponse(success, failure);
// }
}