Initial multi device support refactoring.

1) Store account data as a json type, which includes all
   devices in a single object.

2) Simplify message delivery logic.

3) Make federated calls a pass through to standard controllers.

4) Simplify key retrieval logic.
This commit is contained in:
Moxie Marlinspike
2014-01-18 23:45:07 -08:00
parent 6f9226dcf9
commit 74f71fd8a6
47 changed files with 961 additions and 1211 deletions

View File

@@ -32,8 +32,8 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.sms.SmsSender;
import org.whispersystems.textsecuregcm.sms.TwilioSmsSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.PendingAccountsManager;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.util.VerificationCode;
@@ -54,7 +54,6 @@ import javax.ws.rs.core.Response;
import java.io.IOException;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Arrays;
@Path("/v1/accounts")
public class AccountController {
@@ -97,7 +96,7 @@ public class AccountController {
rateLimiters.getVoiceDestinationLimiter().validate(number);
break;
default:
throw new WebApplicationException(Response.status(415).build());
throw new WebApplicationException(Response.status(422).build());
}
VerificationCode verificationCode = generateVerificationCode();
@@ -137,14 +136,17 @@ public class AccountController {
}
Device device = new Device();
device.setNumber(number);
device.setId(Device.MASTER_ID);
device.setAuthenticationCredentials(new AuthenticationCredentials(password));
device.setSignalingKey(accountAttributes.getSignalingKey());
device.setSupportsSms(accountAttributes.getSupportsSms());
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setDeviceId(0);
accounts.create(new Account(number, accountAttributes.getSupportsSms(), device));
Account account = new Account();
account.setNumber(number);
account.setSupportsSms(accountAttributes.getSupportsSms());
account.addDevice(device);
accounts.create(account);
pendingAccounts.remove(number);
@@ -161,36 +163,40 @@ public class AccountController {
@PUT
@Path("/gcm/")
@Consumes(MediaType.APPLICATION_JSON)
public void setGcmRegistrationId(@Auth Device device, @Valid GcmRegistrationId registrationId) {
device.setApnRegistrationId(null);
device.setGcmRegistrationId(registrationId.getGcmRegistrationId());
accounts.update(device);
public void setGcmRegistrationId(@Auth Account account, @Valid GcmRegistrationId registrationId) {
Device device = account.getAuthenticatedDevice().get();
device.setApnId(null);
device.setGcmId(registrationId.getGcmRegistrationId());
accounts.update(account);
}
@Timed
@DELETE
@Path("/gcm/")
public void deleteGcmRegistrationId(@Auth Device device) {
device.setGcmRegistrationId(null);
accounts.update(device);
public void deleteGcmRegistrationId(@Auth Account account) {
Device device = account.getAuthenticatedDevice().get();
device.setGcmId(null);
accounts.update(account);
}
@Timed
@PUT
@Path("/apn/")
@Consumes(MediaType.APPLICATION_JSON)
public void setApnRegistrationId(@Auth Device device, @Valid ApnRegistrationId registrationId) {
device.setApnRegistrationId(registrationId.getApnRegistrationId());
device.setGcmRegistrationId(null);
accounts.update(device);
public void setApnRegistrationId(@Auth Account account, @Valid ApnRegistrationId registrationId) {
Device device = account.getAuthenticatedDevice().get();
device.setApnId(registrationId.getApnRegistrationId());
device.setGcmId(null);
accounts.update(account);
}
@Timed
@DELETE
@Path("/apn/")
public void deleteApnRegistrationId(@Auth Device device) {
device.setApnRegistrationId(null);
accounts.update(device);
public void deleteApnRegistrationId(@Auth Account account) {
Device device = account.getAuthenticatedDevice().get();
device.setApnId(null);
accounts.update(account);
}
@Timed

View File

@@ -17,6 +17,7 @@
package org.whispersystems.textsecuregcm.controllers;
import com.amazonaws.HttpMethod;
import com.google.common.base.Optional;
import com.yammer.dropwizard.auth.Auth;
import com.yammer.metrics.annotation.Timed;
import org.slf4j.Logger;
@@ -26,7 +27,7 @@ import org.whispersystems.textsecuregcm.entities.AttachmentUri;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.federation.NoSuchPeerException;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.Conversions;
import org.whispersystems.textsecuregcm.util.UrlSigner;
@@ -35,6 +36,7 @@ import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.QueryParam;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.io.IOException;
@@ -64,37 +66,38 @@ public class AttachmentController {
@Timed
@GET
@Produces(MediaType.APPLICATION_JSON)
public Response allocateAttachment(@Auth Device device) throws RateLimitExceededException {
rateLimiters.getAttachmentLimiter().validate(device.getNumber());
public AttachmentDescriptor allocateAttachment(@Auth Account account)
throws RateLimitExceededException
{
if (account.isRateLimited()) {
rateLimiters.getAttachmentLimiter().validate(account.getNumber());
}
long attachmentId = generateAttachmentId();
URL url = urlSigner.getPreSignedUrl(attachmentId, HttpMethod.PUT);
AttachmentDescriptor descriptor = new AttachmentDescriptor(attachmentId, url.toExternalForm());
long attachmentId = generateAttachmentId();
URL url = urlSigner.getPreSignedUrl(attachmentId, HttpMethod.PUT);
return new AttachmentDescriptor(attachmentId, url.toExternalForm());
return Response.ok().entity(descriptor).build();
}
@Timed
@GET
@Produces(MediaType.APPLICATION_JSON)
@Path("/{attachmentId}")
public Response redirectToAttachment(@Auth Device device,
@PathParam("attachmentId") long attachmentId,
@QueryParam("relay") String relay)
public AttachmentUri redirectToAttachment(@Auth Account account,
@PathParam("attachmentId") long attachmentId,
@QueryParam("relay") Optional<String> relay)
throws IOException
{
try {
URL url;
if (relay == null) url = urlSigner.getPreSignedUrl(attachmentId, HttpMethod.GET);
else url = federatedClientManager.getClient(relay).getSignedAttachmentUri(attachmentId);
return Response.ok().entity(new AttachmentUri(url)).build();
} catch (IOException e) {
logger.warn("No conectivity", e);
return Response.status(500).build();
if (!relay.isPresent()) {
return new AttachmentUri(urlSigner.getPreSignedUrl(attachmentId, HttpMethod.GET));
} else {
return new AttachmentUri(federatedClientManager.getClient(relay.get()).getSignedAttachmentUri(attachmentId));
}
} catch (NoSuchPeerException e) {
logger.info("No such peer: " + relay);
return Response.status(404).build();
throw new WebApplicationException(Response.status(404).build());
}
}

View File

@@ -26,7 +26,9 @@ import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.auth.AuthorizationHeader;
import org.whispersystems.textsecuregcm.auth.InvalidAuthorizationHeaderException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.DeviceResponse;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.PendingDevicesManager;
@@ -68,13 +70,13 @@ public class DeviceController {
@GET
@Path("/provisioning_code")
@Produces(MediaType.APPLICATION_JSON)
public VerificationCode createDeviceToken(@Auth Device device)
public VerificationCode createDeviceToken(@Auth Account account)
throws RateLimitExceededException
{
rateLimiters.getVerifyLimiter().validate(device.getNumber()); //TODO: New limiter?
rateLimiters.getVerifyLimiter().validate(account.getNumber()); //TODO: New limiter?
VerificationCode verificationCode = generateVerificationCode();
pendingDevices.store(device.getNumber(), verificationCode.getVerificationCode());
pendingDevices.store(account.getNumber(), verificationCode.getVerificationCode());
return verificationCode;
}
@@ -84,12 +86,11 @@ public class DeviceController {
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
@Path("/{verification_code}")
public long verifyDeviceToken(@PathParam("verification_code") String verificationCode,
@HeaderParam("Authorization") String authorizationHeader,
@Valid AccountAttributes accountAttributes)
public DeviceResponse verifyDeviceToken(@PathParam("verification_code") String verificationCode,
@HeaderParam("Authorization") String authorizationHeader,
@Valid AccountAttributes accountAttributes)
throws RateLimitExceededException
{
Device device;
try {
AuthorizationHeader header = AuthorizationHeader.fromFullHeader(authorizationHeader);
String number = header.getNumber();
@@ -105,24 +106,28 @@ public class DeviceController {
throw new WebApplicationException(Response.status(403).build());
}
device = new Device();
device.setNumber(number);
Optional<Account> account = accounts.get(number);
if (!account.isPresent()) {
throw new WebApplicationException(Response.status(403).build());
}
Device device = new Device();
device.setAuthenticationCredentials(new AuthenticationCredentials(password));
device.setSignalingKey(accountAttributes.getSignalingKey());
device.setSupportsSms(accountAttributes.getSupportsSms());
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setId(account.get().getNextDeviceId());
accounts.provisionDevice(device);
account.get().addDevice(device);
accounts.update(account.get());
pendingDevices.remove(number);
logger.debug("Stored new device device...");
return new DeviceResponse(device.getId());
} catch (InvalidAuthorizationHeaderException e) {
logger.info("Bad Authorization Header", e);
throw new WebApplicationException(Response.status(401).build());
}
return device.getDeviceId();
}
@VisibleForTesting protected VerificationCode generateVerificationCode() {

View File

@@ -21,11 +21,11 @@ import com.yammer.dropwizard.auth.Auth;
import com.yammer.metrics.annotation.Timed;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.entities.ClientContact;
import org.whispersystems.textsecuregcm.entities.ClientContactTokens;
import org.whispersystems.textsecuregcm.entities.ClientContacts;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.DirectoryManager;
import org.whispersystems.textsecuregcm.util.Base64;
@@ -60,10 +60,10 @@ public class DirectoryController {
@GET
@Path("/{token}")
@Produces(MediaType.APPLICATION_JSON)
public Response getTokenPresence(@Auth Device device, @PathParam("token") String token)
public Response getTokenPresence(@Auth Account account, @PathParam("token") String token)
throws RateLimitExceededException
{
rateLimiters.getContactsLimiter().validate(device.getNumber());
rateLimiters.getContactsLimiter().validate(account.getNumber());
try {
Optional<ClientContact> contact = directory.get(Base64.decodeWithoutPadding(token));
@@ -82,10 +82,10 @@ public class DirectoryController {
@Path("/tokens")
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
public ClientContacts getContactIntersection(@Auth Device device, @Valid ClientContactTokens contacts)
public ClientContacts getContactIntersection(@Auth Account account, @Valid ClientContactTokens contacts)
throws RateLimitExceededException
{
rateLimiters.getContactsLimiter().validate(device.getNumber(), contacts.getContacts().size());
rateLimiters.getContactsLimiter().validate(account.getNumber(), contacts.getContacts().size());
try {
List<byte[]> tokens = new LinkedList<>();

View File

@@ -16,9 +16,7 @@
*/
package org.whispersystems.textsecuregcm.controllers;
import com.amazonaws.HttpMethod;
import com.google.common.base.Optional;
import com.google.protobuf.InvalidProtocolBufferException;
import com.yammer.dropwizard.auth.Auth;
import com.yammer.metrics.annotation.Timed;
import org.slf4j.Logger;
@@ -27,40 +25,25 @@ import org.whispersystems.textsecuregcm.entities.AccountCount;
import org.whispersystems.textsecuregcm.entities.AttachmentUri;
import org.whispersystems.textsecuregcm.entities.ClientContact;
import org.whispersystems.textsecuregcm.entities.ClientContacts;
import org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal;
import org.whispersystems.textsecuregcm.entities.MessageResponse;
import org.whispersystems.textsecuregcm.entities.RelayMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList;
import org.whispersystems.textsecuregcm.federation.FederatedPeer;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.federation.NonLimitedAccount;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.UrlSigner;
import org.whispersystems.textsecuregcm.util.Util;
import javax.validation.Valid;
import javax.ws.rs.Consumes;
import javax.ws.rs.GET;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.io.IOException;
import java.net.URL;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static com.google.common.base.Preconditions.checkState;
@Path("/v1/federation")
public class FederationController {
@@ -69,16 +52,20 @@ public class FederationController {
private static final int ACCOUNT_CHUNK_SIZE = 10000;
private final PushSender pushSender;
private final Keys keys;
private final AccountsManager accounts;
private final UrlSigner urlSigner;
private final AccountsManager accounts;
private final AttachmentController attachmentController;
private final KeysController keysController;
private final MessageController messageController;
public FederationController(Keys keys, AccountsManager accounts, PushSender pushSender, UrlSigner urlSigner) {
this.keys = keys;
this.accounts = accounts;
this.pushSender = pushSender;
this.urlSigner = urlSigner;
public FederationController(AccountsManager accounts,
AttachmentController attachmentController,
KeysController keysController,
MessageController messageController)
{
this.accounts = accounts;
this.attachmentController = attachmentController;
this.keysController = keysController;
this.messageController = messageController;
}
@Timed
@@ -87,82 +74,61 @@ public class FederationController {
@Produces(MediaType.APPLICATION_JSON)
public AttachmentUri getSignedAttachmentUri(@Auth FederatedPeer peer,
@PathParam("attachmentId") long attachmentId)
throws IOException
{
URL url = urlSigner.getPreSignedUrl(attachmentId, HttpMethod.GET);
return new AttachmentUri(url);
return attachmentController.redirectToAttachment(new NonLimitedAccount("Unknown", peer.getName()),
attachmentId, Optional.<String>absent());
}
@Timed
@GET
@Path("/key/{number}")
@Produces(MediaType.APPLICATION_JSON)
public UnstructuredPreKeyList getKey(@Auth FederatedPeer peer,
@PathParam("number") String number)
public PreKey getKey(@Auth FederatedPeer peer,
@PathParam("number") String number)
throws IOException
{
Optional<Account> account = accounts.getAccount(number);
UnstructuredPreKeyList keyList = null;
if (account.isPresent())
keyList = keys.get(number, account.get());
if (!account.isPresent() || keyList.getKeys().isEmpty())
throw new WebApplicationException(Response.status(404).build());
return keyList;
try {
return keysController.get(new NonLimitedAccount("Unknown", peer.getName()), number, Optional.<String>absent());
} catch (RateLimitExceededException e) {
logger.warn("Rate limiting on federated channel", e);
throw new IOException(e);
}
}
@Timed
@GET
@Path("/key/{number}/{device}")
@Produces(MediaType.APPLICATION_JSON)
public UnstructuredPreKeyList getKeys(@Auth FederatedPeer peer,
@PathParam("number") String number,
@PathParam("device") String device)
throws IOException
{
try {
return keysController.getDeviceKey(new NonLimitedAccount("Unknown", peer.getName()),
number, device, Optional.<String>absent());
} catch (RateLimitExceededException e) {
logger.warn("Rate limiting on federated channel", e);
throw new IOException(e);
}
}
@Timed
@PUT
@Path("/message")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public MessageResponse relayMessage(@Auth FederatedPeer peer, @Valid List<RelayMessage> messages)
@Path("/messages/{source}/{destination}")
public void sendMessages(@Auth FederatedPeer peer,
@PathParam("source") String source,
@PathParam("destination") String destination,
@Valid IncomingMessageList messages)
throws IOException
{
try {
Map<String, Set<Long>> localDestinations = new HashMap<>();
for (RelayMessage message : messages) {
Set<Long> deviceIds = localDestinations.get(message.getDestination());
if (deviceIds == null) {
deviceIds = new HashSet<>();
localDestinations.put(message.getDestination(), deviceIds);
}
deviceIds.add(message.getDestinationDeviceId());
}
List<Account> localAccounts = null;
try {
localAccounts = accounts.getAccountsForDevices(localDestinations);
} catch (MissingDevicesException e) {
return new MessageResponse(e.missingNumbers);
}
List<String> success = new LinkedList<>();
List<String> failure = new LinkedList<>();
for (RelayMessage message : messages) {
Account destinationAccount = null;
for (Account account : localAccounts)
if (account.getNumber().equals(message.getDestination()))
destinationAccount= account;
checkState(destinationAccount != null);
Device device = destinationAccount.getDevice(message.getDestinationDeviceId());
OutgoingMessageSignal signal = OutgoingMessageSignal.parseFrom(message.getOutgoingMessageSignal())
.toBuilder()
.setRelay(peer.getName())
.build();
try {
pushSender.sendMessage(device, signal);
success.add(device.getBackwardsCompatibleNumberEncoding());
} catch (NoSuchUserException e) {
logger.info("No such user", e);
failure.add(device.getBackwardsCompatibleNumberEncoding());
}
}
return new MessageResponse(success, failure);
} catch (InvalidProtocolBufferException ipe) {
logger.warn("ProtoBuf", ipe);
throw new WebApplicationException(Response.status(400).build());
messages.setRelay(null);
messageController.sendMessage(new NonLimitedAccount(source, peer.getName()), destination, messages);
} catch (RateLimitExceededException e) {
logger.warn("Rate limiting on federated channel", e);
throw new IOException(e);
}
}
@@ -181,15 +147,16 @@ public class FederationController {
public ClientContacts getUserTokens(@Auth FederatedPeer peer,
@PathParam("offset") int offset)
{
List<Device> numberList = accounts.getAllMasterDevices(offset, ACCOUNT_CHUNK_SIZE);
List<Account> accountList = accounts.getAll(offset, ACCOUNT_CHUNK_SIZE);
List<ClientContact> clientContacts = new LinkedList<>();
for (Device device : numberList) {
byte[] token = Util.getContactToken(device.getNumber());
ClientContact clientContact = new ClientContact(token, null, device.getSupportsSms());
for (Account account : accountList) {
byte[] token = Util.getContactToken(account.getNumber());
ClientContact clientContact = new ClientContact(token, null, account.getSupportsSms());
if (!device.isActive())
if (!account.isActive()) {
clientContact.setInactive(true);
}
clientContacts.add(clientContact);
}

View File

@@ -29,7 +29,6 @@ import org.whispersystems.textsecuregcm.federation.NoSuchPeerException;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Keys;
import javax.validation.Valid;
@@ -43,7 +42,6 @@ import javax.ws.rs.QueryParam;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.util.List;
@Path("/v1/keys")
public class KeysController {
@@ -52,46 +50,47 @@ public class KeysController {
private final RateLimiters rateLimiters;
private final Keys keys;
private final AccountsManager accountsManager;
private final FederatedClientManager federatedClientManager;
public KeysController(RateLimiters rateLimiters, Keys keys, AccountsManager accountsManager,
public KeysController(RateLimiters rateLimiters, Keys keys,
FederatedClientManager federatedClientManager)
{
this.rateLimiters = rateLimiters;
this.keys = keys;
this.accountsManager = accountsManager;
this.federatedClientManager = federatedClientManager;
}
@Timed
@PUT
@Consumes(MediaType.APPLICATION_JSON)
public void setKeys(@Auth Device device, @Valid PreKeyList preKeys) {
keys.store(device.getNumber(), device.getDeviceId(), preKeys.getLastResortKey(), preKeys.getKeys());
public void setKeys(@Auth Account account, @Valid PreKeyList preKeys) {
Device device = account.getAuthenticatedDevice().get();
keys.store(account.getNumber(), device.getId(), preKeys.getKeys(), preKeys.getLastResortKey());
}
private List<PreKey> getKeys(Device device, String number, String relay) throws RateLimitExceededException
@Timed
@GET
@Path("/{number}/{device_id}")
@Produces(MediaType.APPLICATION_JSON)
public UnstructuredPreKeyList getDeviceKey(@Auth Account account,
@PathParam("number") String number,
@PathParam("device_id") String deviceId,
@QueryParam("relay") Optional<String> relay)
throws RateLimitExceededException
{
rateLimiters.getPreKeysLimiter().validate(device.getNumber() + "__" + number);
try {
UnstructuredPreKeyList keyList;
if (relay == null) {
Optional<Account> account = accountsManager.getAccount(number);
if (account.isPresent())
keyList = keys.get(number, account.get());
else
throw new WebApplicationException(Response.status(404).build());
} else {
keyList = federatedClientManager.getClient(relay).getKeys(number);
if (account.isRateLimited()) {
rateLimiters.getPreKeysLimiter().validate(account.getNumber() + "__" + number + "." + deviceId);
}
if (keyList == null || keyList.getKeys().isEmpty()) throw new WebApplicationException(Response.status(404).build());
else return keyList.getKeys();
Optional<UnstructuredPreKeyList> results;
if (!relay.isPresent()) results = getLocalKeys(number, deviceId);
else results = federatedClientManager.getClient(relay.get()).getKeys(number, deviceId);
if (results.isPresent()) return results.get();
else throw new WebApplicationException(Response.status(404).build());
} catch (NoSuchPeerException e) {
logger.info("No peer: " + relay);
throw new WebApplicationException(Response.status(404).build());
}
}
@@ -100,15 +99,27 @@ public class KeysController {
@GET
@Path("/{number}")
@Produces(MediaType.APPLICATION_JSON)
public Response get(@Auth Device device,
@PathParam("number") String number,
@QueryParam("multikeys") Optional<String> multikey,
@QueryParam("relay") String relay)
public PreKey get(@Auth Account account,
@PathParam("number") String number,
@QueryParam("relay") Optional<String> relay)
throws RateLimitExceededException
{
if (!multikey.isPresent())
return Response.ok(getKeys(device, number, relay).get(0)).type(MediaType.APPLICATION_JSON).build();
else
return Response.ok(getKeys(device, number, relay)).type(MediaType.APPLICATION_JSON).build();
UnstructuredPreKeyList results = getDeviceKey(account, number, String.valueOf(Device.MASTER_ID), relay);
return results.getKeys().get(0);
}
private Optional<UnstructuredPreKeyList> getLocalKeys(String number, String deviceId) {
try {
if (deviceId.equals("*")) {
return keys.get(number);
}
Optional<PreKey> targetKey = keys.get(number, Long.parseLong(deviceId));
if (targetKey.isPresent()) return Optional.of(new UnstructuredPreKeyList(targetKey.get()));
else return Optional.absent();
} catch (NumberFormatException e) {
throw new WebApplicationException(Response.status(422).build());
}
}
}

View File

@@ -16,383 +16,226 @@
*/
package org.whispersystems.textsecuregcm.controllers;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Optional;
import com.google.protobuf.ByteString;
import com.yammer.dropwizard.auth.AuthenticationException;
import com.yammer.dropwizard.auth.basic.BasicCredentials;
import com.yammer.metrics.Metrics;
import com.yammer.metrics.core.Meter;
import com.yammer.metrics.core.Timer;
import com.yammer.metrics.core.TimerContext;
import com.yammer.dropwizard.auth.Auth;
import com.yammer.metrics.annotation.Timed;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.DeviceAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthorizationHeader;
import org.whispersystems.textsecuregcm.auth.InvalidAuthorizationHeaderException;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal;
import org.whispersystems.textsecuregcm.entities.MessageResponse;
import org.whispersystems.textsecuregcm.entities.RelayMessage;
import org.whispersystems.textsecuregcm.entities.MissingDevices;
import org.whispersystems.textsecuregcm.federation.FederatedClient;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.federation.NoSuchPeerException;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.TransientPushFailureException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Base64;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util;
import javax.annotation.Nullable;
import javax.servlet.AsyncContext;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import javax.validation.Valid;
import javax.ws.rs.Consumes;
import javax.ws.rs.POST;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
public class MessageController extends HttpServlet {
@Path("/v1/messages")
public class MessageController {
public static final String PATH = "/v1/messages/";
private final Meter successMeter = Metrics.newMeter(MessageController.class, "deliver_message", "success", TimeUnit.MINUTES);
private final Meter failureMeter = Metrics.newMeter(MessageController.class, "deliver_message", "failure", TimeUnit.MINUTES);
private final Timer timer = Metrics.newTimer(MessageController.class, "deliver_message_time", TimeUnit.MILLISECONDS, TimeUnit.MINUTES);
private final Logger logger = LoggerFactory.getLogger(MessageController.class);
private final Logger logger = LoggerFactory.getLogger(MessageController.class);
private final RateLimiters rateLimiters;
private final DeviceAuthenticator deviceAuthenticator;
private final PushSender pushSender;
private final FederatedClientManager federatedClientManager;
private final ObjectMapper objectMapper;
private final ExecutorService executor;
private final AccountsManager accountsManager;
public MessageController(RateLimiters rateLimiters,
DeviceAuthenticator deviceAuthenticator,
PushSender pushSender,
AccountsManager accountsManager,
FederatedClientManager federatedClientManager)
{
this.rateLimiters = rateLimiters;
this.deviceAuthenticator = deviceAuthenticator;
this.pushSender = pushSender;
this.accountsManager = accountsManager;
this.federatedClientManager = federatedClientManager;
this.objectMapper = new ObjectMapper();
this.executor = Executors.newFixedThreadPool(10);
}
class LocalOrRemoteDevice {
Device device;
String relay, number; long deviceId;
LocalOrRemoteDevice(Device device) {
this.device = device; this.number = device.getNumber(); this.deviceId = device.getDeviceId();
}
LocalOrRemoteDevice(String relay, String number, long deviceId) {
this.relay = relay; this.number = number; this.deviceId = deviceId;
}
}
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) {
TimerContext timerContext = timer.time();
@Timed
@Path("/{destination}")
@PUT
@Consumes(MediaType.APPLICATION_JSON)
public void sendMessage(@Auth Account source,
@PathParam("destination") String destinationName,
@Valid IncomingMessageList messages)
throws IOException, RateLimitExceededException
{
rateLimiters.getMessagesLimiter().validate(source.getNumber());
try {
Device sender = authenticate(req);
rateLimiters.getMessagesLimiter().validate(sender.getNumber());
handleAsyncDelivery(timerContext, req.startAsync(), sender, parseIncomingMessages(req));
} catch (AuthenticationException e) {
failureMeter.mark();
timerContext.stop();
resp.setStatus(401);
} catch (ValidationException e) {
failureMeter.mark();
timerContext.stop();
resp.setStatus(415);
} catch (IOException e) {
logger.warn("IOE", e);
failureMeter.mark();
timerContext.stop();
resp.setStatus(501);
} catch (RateLimitExceededException e) {
timerContext.stop();
failureMeter.mark();
resp.setStatus(413);
if (messages.getRelay() != null) sendLocalMessage(source, destinationName, messages);
else sendRelayMessage(source, destinationName, messages);
} catch (NoSuchUserException e) {
throw new WebApplicationException(Response.status(404).build());
} catch (MissingDevicesException e) {
throw new WebApplicationException(Response.status(409)
.entity(new MissingDevices(e.getMissingDevices()))
.build());
}
}
private void handleAsyncDelivery(final TimerContext timerContext,
final AsyncContext context,
final Device sender,
final IncomingMessageList messages)
@Timed
@Path("/")
@POST
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public MessageResponse sendMessageLegacy(@Auth Account source, @Valid IncomingMessageList messages)
throws IOException, RateLimitExceededException
{
executor.submit(new Runnable() {
@Override
public void run() {
List<String> success = new LinkedList<>();
List<String> failure = new LinkedList<>();
HttpServletResponse response = (HttpServletResponse) context.getResponse();
try {
List<IncomingMessage> incomingMessages = messages.getMessages();
validateLegacyDestinations(incomingMessages);
try {
List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> outgoingMessages;
try {
outgoingMessages = getOutgoingMessageSignals(sender.getNumber(), messages.getMessages());
} catch (MissingDevicesException e) {
byte[] responseData = serializeResponse(new MessageResponse(e.missingNumbers));
response.setContentLength(responseData.length);
response.getOutputStream().write(responseData);
context.complete();
failureMeter.mark();
timerContext.stop();
return;
}
messages.setRelay(incomingMessages.get(0).getRelay());
sendMessage(source, incomingMessages.get(0).getDestination(), messages);
Map<String, Set<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>>> relayMessages = new HashMap<>();
for (Pair<LocalOrRemoteDevice, OutgoingMessageSignal> messagePair : outgoingMessages) {
String relay = messagePair.first().relay;
if (Util.isEmpty(relay)) {
String encodedId = messagePair.first().device.getBackwardsCompatibleNumberEncoding();
try {
pushSender.sendMessage(messagePair.first().device, messagePair.second());
success.add(encodedId);
} catch (NoSuchUserException e) {
logger.debug("No such user", e);
failure.add(encodedId);
}
} else {
Set<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> messageSet = relayMessages.get(relay);
if (messageSet == null) {
messageSet = new HashSet<>();
relayMessages.put(relay, messageSet);
}
messageSet.add(messagePair);
}
}
for (Map.Entry<String, Set<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>>> messagesForRelay : relayMessages.entrySet()) {
try {
FederatedClient client = federatedClientManager.getClient(messagesForRelay.getKey());
List<RelayMessage> messages = new LinkedList<>();
for (Pair<LocalOrRemoteDevice, OutgoingMessageSignal> message : messagesForRelay.getValue()) {
messages.add(new RelayMessage(message.first().number,
message.first().deviceId,
message.second().toByteArray()));
}
MessageResponse relayResponse = client.sendMessages(messages);
for (String string : relayResponse.getSuccess())
success.add(string);
for (String string : relayResponse.getFailure())
failure.add(string);
} catch (NoSuchPeerException e) {
logger.info("No such peer", e);
for (Pair<LocalOrRemoteDevice, OutgoingMessageSignal> messagePair : messagesForRelay.getValue())
failure.add(messagePair.first().number);
}
}
byte[] responseData = serializeResponse(new MessageResponse(success, failure));
response.setContentLength(responseData.length);
response.getOutputStream().write(responseData);
context.complete();
successMeter.mark();
} catch (IOException e) {
logger.warn("Async Handler", e);
failureMeter.mark();
response.setStatus(501);
context.complete();
} catch (Exception e) {
logger.error("Unknown error sending message", e);
failureMeter.mark();
response.setStatus(500);
context.complete();
}
timerContext.stop();
}
});
return new MessageResponse(new LinkedList<String>(), new LinkedList<String>());
} catch (ValidationException e) {
throw new WebApplicationException(Response.status(422).build());
}
}
@Nullable
private List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> getOutgoingMessageSignals(String sourceNumber,
List<IncomingMessage> incomingMessages)
private void sendLocalMessage(Account source,
String destinationName,
IncomingMessageList messages)
throws NoSuchUserException, MissingDevicesException, IOException
{
Account destination = getDestinationAccount(destinationName);
validateCompleteDeviceList(destination, messages.getMessages());
for (IncomingMessage incomingMessage : messages.getMessages()) {
Optional<Device> destinationDevice = destination.getDevice(incomingMessage.getDestinationDeviceId());
if (destinationDevice.isPresent()) {
sendLocalMessage(source, destination, destinationDevice.get(), incomingMessage);
}
}
}
private void sendLocalMessage(Account source,
Account destinationAccount,
Device destinationDevice,
IncomingMessage incomingMessage)
throws NoSuchUserException, IOException
{
try {
Optional<byte[]> messageBody = getMessageBody(incomingMessage);
OutgoingMessageSignal.Builder messageBuilder = OutgoingMessageSignal.newBuilder();
messageBuilder.setType(incomingMessage.getType())
.setSource(source.getNumber())
.setTimestamp(System.currentTimeMillis());
if (messageBody.isPresent()) {
messageBuilder.setMessage(ByteString.copyFrom(messageBody.get()));
}
if (source.getRelay().isPresent()) {
messageBuilder.setRelay(source.getRelay().get());
}
pushSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build());
} catch (NotPushRegisteredException e) {
if (destinationDevice.isMaster()) throw new NoSuchUserException(e);
else logger.debug("Not registered", e);
} catch (TransientPushFailureException e) {
if (destinationDevice.isMaster()) throw new IOException(e);
else logger.debug("Transient failure", e);
}
}
private void sendRelayMessage(Account source,
String destinationName,
IncomingMessageList messages)
throws IOException, NoSuchUserException
{
try {
FederatedClient client = federatedClientManager.getClient(messages.getRelay());
client.sendMessages(source.getNumber(), destinationName, messages);
} catch (NoSuchPeerException e) {
throw new NoSuchUserException(e);
}
}
private Account getDestinationAccount(String destination)
throws NoSuchUserException
{
Optional<Account> account = accountsManager.get(destination);
if (!account.isPresent() || !account.get().isActive()) {
throw new NoSuchUserException(destination);
}
return account.get();
}
private void validateCompleteDeviceList(Account account, List<IncomingMessage> messages)
throws MissingDevicesException
{
List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> outgoingMessages = new LinkedList<>();
Set<Long> destinationDeviceIds = new HashSet<>();
List<Long> missingDeviceIds = new LinkedList<>();
List<Account> localAccounts = accountsManager.getAccountsForDevices(getLocalDestinations(incomingMessages));
Set<String> destinationNumbers = new HashSet<>();
for (IncomingMessage incoming : incomingMessages)
destinationNumbers.add(incoming.getDestination());
for (IncomingMessage incoming : incomingMessages) {
OutgoingMessageSignal.Builder outgoingMessage = OutgoingMessageSignal.newBuilder();
outgoingMessage.setType(incoming.getType());
outgoingMessage.setSource(sourceNumber);
byte[] messageBody = getMessageBody(incoming);
if (messageBody != null) {
outgoingMessage.setMessage(ByteString.copyFrom(messageBody));
}
outgoingMessage.setTimestamp(System.currentTimeMillis());
for (String destination : destinationNumbers) {
if (!destination.equals(incoming.getDestination()))
outgoingMessage.addDestinations(destination);
}
LocalOrRemoteDevice device = null;
if (!Util.isEmpty(incoming.getRelay()))
device = new LocalOrRemoteDevice(incoming.getRelay(), incoming.getDestination(), incoming.getDestinationDeviceId());
else {
Account destination = null;
for (Account account : localAccounts) {
if (account.getNumber().equals(incoming.getDestination())) {
destination = account;
break;
}
}
if (destination != null)
device = new LocalOrRemoteDevice(destination.getDevice(incoming.getDestinationDeviceId()));
}
if (device != null)
outgoingMessages.add(new Pair<>(device, outgoingMessage.build()));
for (IncomingMessage message : messages) {
destinationDeviceIds.add(message.getDestinationDeviceId());
}
return outgoingMessages;
}
// We use a map from number -> deviceIds here (instead of passing the list of messages to accountsManager) so that
// we can share as much code as possible with FederationController (which has RelayMessages, not IncomingMessages)
private Map<String, Set<Long>> getLocalDestinations(List<IncomingMessage> incomingMessages) {
Map<String, Set<Long>> localDestinations = new HashMap<>();
for (IncomingMessage incoming : incomingMessages) {
if (!Util.isEmpty(incoming.getRelay()))
continue;
Set<Long> deviceIds = localDestinations.get(incoming.getDestination());
if (deviceIds == null) {
deviceIds = new HashSet<>();
localDestinations.put(incoming.getDestination(), deviceIds);
for (Device device : account.getDevices()) {
if (!destinationDeviceIds.contains(device.getId())) {
missingDeviceIds.add(device.getId());
}
deviceIds.add(incoming.getDestinationDeviceId());
}
return localDestinations;
}
private byte[] getMessageBody(IncomingMessage message) {
try {
return Base64.decode(message.getBody());
} catch (IOException ioe) {
ioe.printStackTrace();
return null;
if (!missingDeviceIds.isEmpty()) {
throw new MissingDevicesException(missingDeviceIds);
}
}
private byte[] serializeResponse(MessageResponse response) throws IOException {
try {
return objectMapper.writeValueAsBytes(response);
} catch (JsonProcessingException e) {
throw new IOException(e);
}
}
private IncomingMessageList parseIncomingMessages(HttpServletRequest request)
throws IOException, ValidationException
private void validateLegacyDestinations(List<IncomingMessage> messages)
throws ValidationException
{
BufferedReader reader = request.getReader();
StringBuilder content = new StringBuilder();
String line;
String destination = null;
while ((line = reader.readLine()) != null) {
content.append(line);
for (IncomingMessage message : messages) {
if (destination != null && !destination.equals(message.getDestination())) {
throw new ValidationException("Multiple account destinations!");
}
destination = message.getDestination();
}
IncomingMessageList messages = objectMapper.readValue(content.toString(),
IncomingMessageList.class);
if (messages.getMessages() == null) {
throw new ValidationException();
}
for (IncomingMessage message : messages.getMessages()) {
if (message.getBody() == null) throw new ValidationException();
if (message.getDestination() == null) throw new ValidationException();
}
return messages;
}
private Device authenticate(HttpServletRequest request) throws AuthenticationException {
private Optional<byte[]> getMessageBody(IncomingMessage message) {
try {
AuthorizationHeader authorizationHeader = AuthorizationHeader.fromFullHeader(request.getHeader("Authorization"));
BasicCredentials credentials = new BasicCredentials(authorizationHeader.getNumber() + "." + authorizationHeader.getDeviceId(),
authorizationHeader.getPassword() );
Optional<Device> account = deviceAuthenticator.authenticate(credentials);
if (account.isPresent()) return account.get();
else throw new AuthenticationException("Bad credentials");
} catch (InvalidAuthorizationHeaderException e) {
throw new AuthenticationException(e);
return Optional.of(Base64.decode(message.getBody()));
} catch (IOException ioe) {
logger.debug("Bad B64", ioe);
return Optional.absent();
}
}
// @Timed
// @POST
// @Consumes(MediaType.APPLICATION_JSON)
// @Produces(MediaType.APPLICATION_JSON)
// public MessageResponse sendMessage(@Auth Device sender, IncomingMessageList messages)
// throws IOException
// {
// List<String> success = new LinkedList<>();
// List<String> failure = new LinkedList<>();
// List<IncomingMessage> incomingMessages = messages.getMessages();
// List<OutgoingMessageSignal> outgoingMessages = getOutgoingMessageSignals(sender.getNumber(), incomingMessages);
//
// IterablePair<IncomingMessage, OutgoingMessageSignal> listPair = new IterablePair<>(incomingMessages, outgoingMessages);
//
// for (Pair<IncomingMessage, OutgoingMessageSignal> messagePair : listPair) {
// String destination = messagePair.first().getDestination();
// String relay = messagePair.first().getRelay();
//
// try {
// if (Util.isEmpty(relay)) sendLocalMessage(destination, messagePair.second());
// else sendRelayMessage(relay, destination, messagePair.second());
// success.add(destination);
// } catch (NoSuchUserException e) {
// logger.debug("No such user", e);
// failure.add(destination);
// }
// }
//
// return new MessageResponse(success, failure);
// }
}

View File

@@ -4,8 +4,13 @@ import java.util.List;
import java.util.Set;
public class MissingDevicesException extends Exception {
public Set<String> missingNumbers;
public MissingDevicesException(Set<String> missingNumbers) {
this.missingNumbers = missingNumbers;
private final List<Long> missingDevices;
public MissingDevicesException(List<Long> missingDevices) {
this.missingDevices = missingDevices;
}
public List<Long> getMissingDevices() {
return missingDevices;
}
}

View File

@@ -18,4 +18,7 @@ package org.whispersystems.textsecuregcm.controllers;
public class ValidationException extends Exception {
public ValidationException(String s) {
super(s);
}
}