Simplify message handling by returning early and throwing out maps

This commit is contained in:
Matt Corallo
2014-01-11 16:30:37 -10:00
parent 7af3c51cc4
commit eedaa8b3f4
7 changed files with 244 additions and 202 deletions

View File

@@ -45,7 +45,6 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Base64;
import org.whispersystems.textsecuregcm.util.IterablePair;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util;
@@ -54,12 +53,10 @@ import javax.servlet.AsyncContext;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.Path;
import java.io.BufferedReader;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
@@ -117,16 +114,9 @@ public class MessageController extends HttpServlet {
try {
Device sender = authenticate(req);
IncomingMessageList messages = parseIncomingMessages(req);
rateLimiters.getMessagesLimiter().validate(sender.getNumber());
List<String> numbersMissingDevices = new LinkedList<>();
List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> outgoingMessages =
getOutgoingMessageSignals(sender.getNumber(), messages.getMessages(), numbersMissingDevices);
handleAsyncDelivery(timerContext, req.startAsync(), outgoingMessages, numbersMissingDevices);
handleAsyncDelivery(timerContext, req.startAsync(), sender, parseIncomingMessages(req));
} catch (AuthenticationException e) {
failureMeter.mark();
timerContext.stop();
@@ -149,19 +139,32 @@ public class MessageController extends HttpServlet {
private void handleAsyncDelivery(final TimerContext timerContext,
final AsyncContext context,
final List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> listPair,
final List<String> numbersMissingDevices)
final Device sender,
final IncomingMessageList messages)
{
executor.submit(new Runnable() {
@Override
public void run() {
List<String> success = new LinkedList<>();
List<String> failure = new LinkedList<>(numbersMissingDevices);
List<String> failure = new LinkedList<>();
HttpServletResponse response = (HttpServletResponse) context.getResponse();
try {
List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> outgoingMessages;
try {
outgoingMessages = getOutgoingMessageSignals(sender.getNumber(), messages.getMessages());
} catch (MissingDevicesException e) {
byte[] responseData = serializeResponse(new MessageResponse(e.missingNumbers));
response.setContentLength(responseData.length);
response.getOutputStream().write(responseData);
context.complete();
failureMeter.mark();
timerContext.stop();
return;
}
Map<String, Set<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>>> relayMessages = new HashMap<>();
for (Pair<LocalOrRemoteDevice, OutgoingMessageSignal> messagePair : listPair) {
for (Pair<LocalOrRemoteDevice, OutgoingMessageSignal> messagePair : outgoingMessages) {
String relay = messagePair.first().relay;
if (Util.isEmpty(relay)) {
@@ -199,8 +202,6 @@ public class MessageController extends HttpServlet {
success.add(string);
for (String string : relayResponse.getFailure())
failure.add(string);
for (String string : relayResponse.getNumbersMissingDevices())
numbersMissingDevices.add(string);
} catch (NoSuchPeerException e) {
logger.info("No such peer", e);
for (Pair<LocalOrRemoteDevice, OutgoingMessageSignal> messagePair : messagesForRelay.getValue())
@@ -208,7 +209,7 @@ public class MessageController extends HttpServlet {
}
}
byte[] responseData = serializeResponse(new MessageResponse(success, failure, numbersMissingDevices));
byte[] responseData = serializeResponse(new MessageResponse(success, failure));
response.setContentLength(responseData.length);
response.getOutputStream().write(responseData);
context.complete();
@@ -232,30 +233,16 @@ public class MessageController extends HttpServlet {
@Nullable
private List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> getOutgoingMessageSignals(String sourceNumber,
List<IncomingMessage> incomingMessages,
List<String> numbersMissingDevices)
List<IncomingMessage> incomingMessages)
throws MissingDevicesException
{
List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> outgoingMessages = new LinkedList<>();
Map<String, Set<Long>> localDestinations = new HashMap<>();
List<Account> localAccounts = accountsManager.getAccountsForDevices(getLocalDestinations(incomingMessages));
Set<String> destinationNumbers = new HashSet<>();
for (IncomingMessage incoming : incomingMessages) {
for (IncomingMessage incoming : incomingMessages)
destinationNumbers.add(incoming.getDestination());
if (!Util.isEmpty(incoming.getRelay()))
continue;
Set<Long> deviceIds = localDestinations.get(incoming.getDestination());
if (deviceIds == null) {
deviceIds = new HashSet<>();
localDestinations.put(incoming.getDestination(), deviceIds);
}
deviceIds.add(incoming.getDestinationDeviceId());
}
Pair<Map<String, Account>, List<String>> accountsForDevices = accountsManager.getAccountsForDevices(localDestinations);
Map<String, Account> localAccounts = accountsForDevices.first();
for (String number : accountsForDevices.second())
numbersMissingDevices.add(number);
for (IncomingMessage incoming : incomingMessages) {
OutgoingMessageSignal.Builder outgoingMessage = OutgoingMessageSignal.newBuilder();
@@ -281,7 +268,14 @@ public class MessageController extends HttpServlet {
if (!Util.isEmpty(incoming.getRelay()))
device = new LocalOrRemoteDevice(incoming.getRelay(), incoming.getDestination(), incoming.getDestinationDeviceId());
else {
Account destination = localAccounts.get(incoming.getDestination());
Account destination = null;
for (Account account : localAccounts) {
if (account.getNumber().equals(incoming.getDestination())) {
destination = account;
break;
}
}
if (destination != null)
device = new LocalOrRemoteDevice(destination.getDevice(incoming.getDestinationDeviceId()));
}
@@ -293,6 +287,24 @@ public class MessageController extends HttpServlet {
return outgoingMessages;
}
// We use a map from number -> deviceIds here (instead of passing the list of messages to accountsManager) so that
// we can share as much code as possible with FederationController (which has RelayMessages, not IncomingMessages)
private Map<String, Set<Long>> getLocalDestinations(List<IncomingMessage> incomingMessages) {
Map<String, Set<Long>> localDestinations = new HashMap<>();
for (IncomingMessage incoming : incomingMessages) {
if (!Util.isEmpty(incoming.getRelay()))
continue;
Set<Long> deviceIds = localDestinations.get(incoming.getDestination());
if (deviceIds == null) {
deviceIds = new HashSet<>();
localDestinations.put(incoming.getDestination(), deviceIds);
}
deviceIds.add(incoming.getDestinationDeviceId());
}
return localDestinations;
}
private byte[] getMessageBody(IncomingMessage message) {
try {
return Base64.decode(message.getBody());