mirror of
https://github.com/signalapp/Signal-Server
synced 2026-04-25 18:08:02 +01:00
Simplify message handling by returning early and throwing out maps
This commit is contained in:
@@ -60,6 +60,8 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import static com.google.common.base.Preconditions.checkState;
|
||||
|
||||
@Path("/v1/federation")
|
||||
public class FederationController {
|
||||
|
||||
@@ -125,17 +127,24 @@ public class FederationController {
|
||||
deviceIds.add(message.getDestinationDeviceId());
|
||||
}
|
||||
|
||||
Pair<Map<String, Account>, List<String>> accountsForDevices = accounts.getAccountsForDevices(localDestinations);
|
||||
List<Account> localAccounts = null;
|
||||
try {
|
||||
localAccounts = accounts.getAccountsForDevices(localDestinations);
|
||||
} catch (MissingDevicesException e) {
|
||||
return new MessageResponse(e.missingNumbers);
|
||||
}
|
||||
|
||||
Map<String, Account> localAccounts = accountsForDevices.first();
|
||||
List<String> numbersMissingDevices = accountsForDevices.second();
|
||||
List<String> success = new LinkedList<>();
|
||||
List<String> failure = new LinkedList<>(numbersMissingDevices);
|
||||
List<String> failure = new LinkedList<>();
|
||||
|
||||
for (RelayMessage message : messages) {
|
||||
Account destinationAccount = localAccounts.get(message.getDestination());
|
||||
if (destinationAccount == null)
|
||||
continue;
|
||||
Account destinationAccount = null;
|
||||
for (Account account : localAccounts)
|
||||
if (account.getNumber().equals(message.getDestination()))
|
||||
destinationAccount= account;
|
||||
|
||||
checkState(destinationAccount != null);
|
||||
|
||||
Device device = destinationAccount.getDevice(message.getDestinationDeviceId());
|
||||
OutgoingMessageSignal signal = OutgoingMessageSignal.parseFrom(message.getOutgoingMessageSignal())
|
||||
.toBuilder()
|
||||
@@ -150,7 +159,7 @@ public class FederationController {
|
||||
}
|
||||
}
|
||||
|
||||
return new MessageResponse(success, failure, numbersMissingDevices);
|
||||
return new MessageResponse(success, failure);
|
||||
} catch (InvalidProtocolBufferException ipe) {
|
||||
logger.warn("ProtoBuf", ipe);
|
||||
throw new WebApplicationException(Response.status(400).build());
|
||||
|
||||
@@ -45,7 +45,6 @@ import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.util.Base64;
|
||||
import org.whispersystems.textsecuregcm.util.IterablePair;
|
||||
import org.whispersystems.textsecuregcm.util.Pair;
|
||||
import org.whispersystems.textsecuregcm.util.Util;
|
||||
|
||||
@@ -54,12 +53,10 @@ import javax.servlet.AsyncContext;
|
||||
import javax.servlet.http.HttpServlet;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import javax.ws.rs.Path;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Iterator;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -117,16 +114,9 @@ public class MessageController extends HttpServlet {
|
||||
|
||||
try {
|
||||
Device sender = authenticate(req);
|
||||
IncomingMessageList messages = parseIncomingMessages(req);
|
||||
|
||||
rateLimiters.getMessagesLimiter().validate(sender.getNumber());
|
||||
|
||||
List<String> numbersMissingDevices = new LinkedList<>();
|
||||
|
||||
List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> outgoingMessages =
|
||||
getOutgoingMessageSignals(sender.getNumber(), messages.getMessages(), numbersMissingDevices);
|
||||
|
||||
handleAsyncDelivery(timerContext, req.startAsync(), outgoingMessages, numbersMissingDevices);
|
||||
handleAsyncDelivery(timerContext, req.startAsync(), sender, parseIncomingMessages(req));
|
||||
} catch (AuthenticationException e) {
|
||||
failureMeter.mark();
|
||||
timerContext.stop();
|
||||
@@ -149,19 +139,32 @@ public class MessageController extends HttpServlet {
|
||||
|
||||
private void handleAsyncDelivery(final TimerContext timerContext,
|
||||
final AsyncContext context,
|
||||
final List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> listPair,
|
||||
final List<String> numbersMissingDevices)
|
||||
final Device sender,
|
||||
final IncomingMessageList messages)
|
||||
{
|
||||
executor.submit(new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
List<String> success = new LinkedList<>();
|
||||
List<String> failure = new LinkedList<>(numbersMissingDevices);
|
||||
List<String> failure = new LinkedList<>();
|
||||
HttpServletResponse response = (HttpServletResponse) context.getResponse();
|
||||
|
||||
try {
|
||||
List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> outgoingMessages;
|
||||
try {
|
||||
outgoingMessages = getOutgoingMessageSignals(sender.getNumber(), messages.getMessages());
|
||||
} catch (MissingDevicesException e) {
|
||||
byte[] responseData = serializeResponse(new MessageResponse(e.missingNumbers));
|
||||
response.setContentLength(responseData.length);
|
||||
response.getOutputStream().write(responseData);
|
||||
context.complete();
|
||||
failureMeter.mark();
|
||||
timerContext.stop();
|
||||
return;
|
||||
}
|
||||
|
||||
Map<String, Set<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>>> relayMessages = new HashMap<>();
|
||||
for (Pair<LocalOrRemoteDevice, OutgoingMessageSignal> messagePair : listPair) {
|
||||
for (Pair<LocalOrRemoteDevice, OutgoingMessageSignal> messagePair : outgoingMessages) {
|
||||
String relay = messagePair.first().relay;
|
||||
|
||||
if (Util.isEmpty(relay)) {
|
||||
@@ -199,8 +202,6 @@ public class MessageController extends HttpServlet {
|
||||
success.add(string);
|
||||
for (String string : relayResponse.getFailure())
|
||||
failure.add(string);
|
||||
for (String string : relayResponse.getNumbersMissingDevices())
|
||||
numbersMissingDevices.add(string);
|
||||
} catch (NoSuchPeerException e) {
|
||||
logger.info("No such peer", e);
|
||||
for (Pair<LocalOrRemoteDevice, OutgoingMessageSignal> messagePair : messagesForRelay.getValue())
|
||||
@@ -208,7 +209,7 @@ public class MessageController extends HttpServlet {
|
||||
}
|
||||
}
|
||||
|
||||
byte[] responseData = serializeResponse(new MessageResponse(success, failure, numbersMissingDevices));
|
||||
byte[] responseData = serializeResponse(new MessageResponse(success, failure));
|
||||
response.setContentLength(responseData.length);
|
||||
response.getOutputStream().write(responseData);
|
||||
context.complete();
|
||||
@@ -232,30 +233,16 @@ public class MessageController extends HttpServlet {
|
||||
|
||||
@Nullable
|
||||
private List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> getOutgoingMessageSignals(String sourceNumber,
|
||||
List<IncomingMessage> incomingMessages,
|
||||
List<String> numbersMissingDevices)
|
||||
List<IncomingMessage> incomingMessages)
|
||||
throws MissingDevicesException
|
||||
{
|
||||
List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> outgoingMessages = new LinkedList<>();
|
||||
Map<String, Set<Long>> localDestinations = new HashMap<>();
|
||||
|
||||
List<Account> localAccounts = accountsManager.getAccountsForDevices(getLocalDestinations(incomingMessages));
|
||||
|
||||
Set<String> destinationNumbers = new HashSet<>();
|
||||
for (IncomingMessage incoming : incomingMessages) {
|
||||
for (IncomingMessage incoming : incomingMessages)
|
||||
destinationNumbers.add(incoming.getDestination());
|
||||
if (!Util.isEmpty(incoming.getRelay()))
|
||||
continue;
|
||||
|
||||
Set<Long> deviceIds = localDestinations.get(incoming.getDestination());
|
||||
if (deviceIds == null) {
|
||||
deviceIds = new HashSet<>();
|
||||
localDestinations.put(incoming.getDestination(), deviceIds);
|
||||
}
|
||||
deviceIds.add(incoming.getDestinationDeviceId());
|
||||
}
|
||||
|
||||
Pair<Map<String, Account>, List<String>> accountsForDevices = accountsManager.getAccountsForDevices(localDestinations);
|
||||
|
||||
Map<String, Account> localAccounts = accountsForDevices.first();
|
||||
for (String number : accountsForDevices.second())
|
||||
numbersMissingDevices.add(number);
|
||||
|
||||
for (IncomingMessage incoming : incomingMessages) {
|
||||
OutgoingMessageSignal.Builder outgoingMessage = OutgoingMessageSignal.newBuilder();
|
||||
@@ -281,7 +268,14 @@ public class MessageController extends HttpServlet {
|
||||
if (!Util.isEmpty(incoming.getRelay()))
|
||||
device = new LocalOrRemoteDevice(incoming.getRelay(), incoming.getDestination(), incoming.getDestinationDeviceId());
|
||||
else {
|
||||
Account destination = localAccounts.get(incoming.getDestination());
|
||||
Account destination = null;
|
||||
for (Account account : localAccounts) {
|
||||
if (account.getNumber().equals(incoming.getDestination())) {
|
||||
destination = account;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (destination != null)
|
||||
device = new LocalOrRemoteDevice(destination.getDevice(incoming.getDestinationDeviceId()));
|
||||
}
|
||||
@@ -293,6 +287,24 @@ public class MessageController extends HttpServlet {
|
||||
return outgoingMessages;
|
||||
}
|
||||
|
||||
// We use a map from number -> deviceIds here (instead of passing the list of messages to accountsManager) so that
|
||||
// we can share as much code as possible with FederationController (which has RelayMessages, not IncomingMessages)
|
||||
private Map<String, Set<Long>> getLocalDestinations(List<IncomingMessage> incomingMessages) {
|
||||
Map<String, Set<Long>> localDestinations = new HashMap<>();
|
||||
for (IncomingMessage incoming : incomingMessages) {
|
||||
if (!Util.isEmpty(incoming.getRelay()))
|
||||
continue;
|
||||
|
||||
Set<Long> deviceIds = localDestinations.get(incoming.getDestination());
|
||||
if (deviceIds == null) {
|
||||
deviceIds = new HashSet<>();
|
||||
localDestinations.put(incoming.getDestination(), deviceIds);
|
||||
}
|
||||
deviceIds.add(incoming.getDestinationDeviceId());
|
||||
}
|
||||
return localDestinations;
|
||||
}
|
||||
|
||||
private byte[] getMessageBody(IncomingMessage message) {
|
||||
try {
|
||||
return Base64.decode(message.getBody());
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
package org.whispersystems.textsecuregcm.controllers;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
public class MissingDevicesException extends Exception {
|
||||
public Set<String> missingNumbers;
|
||||
public MissingDevicesException(Set<String> missingNumbers) {
|
||||
this.missingNumbers = missingNumbers;
|
||||
}
|
||||
}
|
||||
@@ -16,16 +16,25 @@
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.entities;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
public class MessageResponse {
|
||||
private List<String> success;
|
||||
private List<String> failure;
|
||||
private List<String> missingDeviceIds;
|
||||
private Set<String> missingDeviceIds;
|
||||
|
||||
public MessageResponse(List<String> success, List<String> failure, List<String> missingDeviceIds) {
|
||||
public MessageResponse(List<String> success, List<String> failure) {
|
||||
this.success = success;
|
||||
this.failure = failure;
|
||||
this.missingDeviceIds = new HashSet<>();
|
||||
}
|
||||
|
||||
public MessageResponse(Set<String> missingDeviceIds) {
|
||||
this.success = new LinkedList<>();
|
||||
this.failure = new LinkedList<>(missingDeviceIds);
|
||||
this.missingDeviceIds = missingDeviceIds;
|
||||
}
|
||||
|
||||
@@ -35,11 +44,23 @@ public class MessageResponse {
|
||||
return success;
|
||||
}
|
||||
|
||||
public void setSuccess(List<String> success) {
|
||||
this.success = success;
|
||||
}
|
||||
|
||||
public List<String> getFailure() {
|
||||
return failure;
|
||||
}
|
||||
|
||||
public List<String> getNumbersMissingDevices() {
|
||||
public void setFailure(List<String> failure) {
|
||||
this.failure = failure;
|
||||
}
|
||||
|
||||
public Set<String> getNumbersMissingDevices() {
|
||||
return missingDeviceIds;
|
||||
}
|
||||
|
||||
public void setNumbersMissingDevices(Set<String> numbersMissingDevices) {
|
||||
this.missingDeviceIds = numbersMissingDevices;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ package org.whispersystems.textsecuregcm.storage;
|
||||
|
||||
import com.google.common.base.Optional;
|
||||
import net.spy.memcached.MemcachedClient;
|
||||
import org.whispersystems.textsecuregcm.controllers.MissingDevicesException;
|
||||
import org.whispersystems.textsecuregcm.entities.ClientContact;
|
||||
import org.whispersystems.textsecuregcm.util.Pair;
|
||||
import org.whispersystems.textsecuregcm.util.Util;
|
||||
@@ -119,36 +120,30 @@ public class AccountsManager {
|
||||
return Optional.of(new Account(number, devices.get(0).getSupportsSms(), devices));
|
||||
}
|
||||
|
||||
private Map<String, Account> getAllAccounts(Set<String> numbers) {
|
||||
private List<Account> getAllAccounts(Set<String> numbers) {
|
||||
//TODO: ONE QUERY
|
||||
Map<String, Account> result = new HashMap<>();
|
||||
List<Account> accounts = new LinkedList<>();
|
||||
for (String number : numbers) {
|
||||
Optional<Account> account = getAccount(number);
|
||||
if (account.isPresent())
|
||||
result.put(number, account.get());
|
||||
accounts.add(account.get());
|
||||
}
|
||||
return result;
|
||||
return accounts;
|
||||
}
|
||||
|
||||
public Pair<Map<String, Account>, List<String>> getAccountsForDevices(Map<String, Set<Long>> destinations) {
|
||||
List<String> numbersMissingDevices = new LinkedList<>();
|
||||
Map<String, Account> localAccounts = getAllAccounts(destinations.keySet());
|
||||
public List<Account> getAccountsForDevices(Map<String, Set<Long>> destinations) throws MissingDevicesException {
|
||||
Set<String> numbersMissingDevices = new HashSet<>(destinations.keySet());
|
||||
List<Account> localAccounts = getAllAccounts(destinations.keySet());
|
||||
|
||||
for (String number : destinations.keySet()) {
|
||||
if (localAccounts.get(number) == null)
|
||||
numbersMissingDevices.add(number);
|
||||
for (Account account : localAccounts){
|
||||
if (account.hasAllDeviceIds(destinations.get(account.getNumber())))
|
||||
numbersMissingDevices.remove(account.getNumber());
|
||||
}
|
||||
|
||||
Iterator<Account> localAccountIterator = localAccounts.values().iterator();
|
||||
while (localAccountIterator.hasNext()) {
|
||||
Account account = localAccountIterator.next();
|
||||
if (!account.hasAllDeviceIds(destinations.get(account.getNumber()))) {
|
||||
numbersMissingDevices.add(account.getNumber());
|
||||
localAccountIterator.remove();
|
||||
}
|
||||
}
|
||||
if (!numbersMissingDevices.isEmpty())
|
||||
throw new MissingDevicesException(numbersMissingDevices);
|
||||
|
||||
return new Pair<>(localAccounts, numbersMissingDevices);
|
||||
return localAccounts;
|
||||
}
|
||||
|
||||
private void updateDirectory(Device device) {
|
||||
|
||||
Reference in New Issue
Block a user