Add support for "registrationId" session enforcement.

This commit is contained in:
Moxie Marlinspike
2014-02-20 09:32:42 -08:00
parent 35e212a30f
commit f4ecb5d7be
18 changed files with 204 additions and 32 deletions

View File

@@ -131,7 +131,7 @@ public class WhisperServerService extends Service<WhisperServerConfiguration> {
accountsManager);
AttachmentController attachmentController = new AttachmentController(rateLimiters, federatedClientManager, urlSigner);
KeysController keysController = new KeysController(rateLimiters, keys, federatedClientManager);
KeysController keysController = new KeysController(rateLimiters, keys, accountsManager, federatedClientManager);
MessageController messageController = new MessageController(rateLimiters, pushSender, accountsManager, federatedClientManager);
environment.addProvider(new MultiBasicAuthProvider<>(new FederatedPeerAuthenticator(config.getFederationConfiguration()),

View File

@@ -140,6 +140,7 @@ public class AccountController {
device.setAuthenticationCredentials(new AuthenticationCredentials(password));
device.setSignalingKey(accountAttributes.getSignalingKey());
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setRegistrationId(accountAttributes.getRegistrationId());
Account account = new Account();
account.setNumber(number);

View File

@@ -28,6 +28,7 @@ import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.federation.NoSuchPeerException;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Keys;
@@ -42,6 +43,8 @@ import javax.ws.rs.QueryParam;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.util.LinkedList;
import java.util.List;
@Path("/v1/keys")
public class KeysController {
@@ -50,13 +53,15 @@ public class KeysController {
private final RateLimiters rateLimiters;
private final Keys keys;
private final AccountsManager accounts;
private final FederatedClientManager federatedClientManager;
public KeysController(RateLimiters rateLimiters, Keys keys,
public KeysController(RateLimiters rateLimiters, Keys keys, AccountsManager accounts,
FederatedClientManager federatedClientManager)
{
this.rateLimiters = rateLimiters;
this.keys = keys;
this.accounts = accounts;
this.federatedClientManager = federatedClientManager;
}
@@ -108,18 +113,50 @@ public class KeysController {
return results.getKeys().get(0);
}
private Optional<UnstructuredPreKeyList> getLocalKeys(String number, String deviceId) {
private Optional<UnstructuredPreKeyList> getLocalKeys(String number, String deviceIdSelector) {
Optional<Account> destination = accounts.get(number);
if (!destination.isPresent() || !destination.get().isActive()) {
return Optional.absent();
}
try {
if (deviceId.equals("*")) {
return keys.get(number);
if (deviceIdSelector.equals("*")) {
Optional<UnstructuredPreKeyList> preKeys = keys.get(number);
return getActiveKeys(destination.get(), preKeys);
}
Optional<PreKey> targetKey = keys.get(number, Long.parseLong(deviceId));
long deviceId = Long.parseLong(deviceIdSelector);
Optional<Device> targetDevice = destination.get().getDevice(deviceId);
if (targetKey.isPresent()) return Optional.of(new UnstructuredPreKeyList(targetKey.get()));
else return Optional.absent();
if (!targetDevice.isPresent() || !targetDevice.get().isActive()) {
return Optional.absent();
}
Optional<UnstructuredPreKeyList> preKeys = keys.get(number, deviceId);
return getActiveKeys(destination.get(), preKeys);
} catch (NumberFormatException e) {
throw new WebApplicationException(Response.status(422).build());
}
}
private Optional<UnstructuredPreKeyList> getActiveKeys(Account destination,
Optional<UnstructuredPreKeyList> preKeys)
{
if (!preKeys.isPresent()) return Optional.absent();
List<PreKey> filteredKeys = new LinkedList<>();
for (PreKey preKey : preKeys.get().getKeys()) {
Optional<Device> device = destination.getDevice(preKey.getDeviceId());
if (device.isPresent() && device.get().isActive()) {
preKey.setRegistrationId(device.get().getRegistrationId());
filteredKeys.add(preKey);
}
}
if (filteredKeys.isEmpty()) return Optional.absent();
else return Optional.of(new UnstructuredPreKeyList(filteredKeys));
}
}

View File

@@ -27,6 +27,7 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal;
import org.whispersystems.textsecuregcm.entities.MessageResponse;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.federation.FederatedClient;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.federation.NoSuchPeerException;
@@ -98,6 +99,11 @@ public class MessageController {
.entity(new MismatchedDevices(e.getMissingDevices(),
e.getExtraDevices()))
.build());
} catch (StaleDevicesException e) {
throw new WebApplicationException(Response.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(new StaleDevices(e.getStaleDevices()))
.build());
}
}
@@ -124,11 +130,12 @@ public class MessageController {
private void sendLocalMessage(Account source,
String destinationName,
IncomingMessageList messages)
throws NoSuchUserException, MismatchedDevicesException, IOException
throws NoSuchUserException, MismatchedDevicesException, IOException, StaleDevicesException
{
Account destination = getDestinationAccount(destinationName);
validateCompleteDeviceList(destination, messages.getMessages());
validateRegistrationIds(destination, messages.getMessages());
for (IncomingMessage incomingMessage : messages.getMessages()) {
Optional<Device> destinationDevice = destination.getDevice(incomingMessage.getDestinationDeviceId());
@@ -197,6 +204,27 @@ public class MessageController {
return account.get();
}
private void validateRegistrationIds(Account account, List<IncomingMessage> messages)
throws StaleDevicesException
{
List<Long> staleDevices = new LinkedList<>();
for (IncomingMessage message : messages) {
Optional<Device> device = account.getDevice(message.getDestinationDeviceId());
if (device.isPresent() &&
message.getDestinationRegistrationId() > 0 &&
message.getDestinationRegistrationId() != device.get().getRegistrationId())
{
staleDevices.add(device.get().getId());
}
}
if (!staleDevices.isEmpty()) {
throw new StaleDevicesException(staleDevices);
}
}
private void validateCompleteDeviceList(Account account, List<IncomingMessage> messages)
throws MismatchedDevicesException
{
@@ -211,10 +239,12 @@ public class MessageController {
}
for (Device device : account.getDevices()) {
accountDeviceIds.add(device.getId());
if (device.isActive()) {
accountDeviceIds.add(device.getId());
if (!messageDeviceIds.contains(device.getId())) {
missingDeviceIds.add(device.getId());
if (!messageDeviceIds.contains(device.getId())) {
missingDeviceIds.add(device.getId());
}
}
}

View File

@@ -0,0 +1,16 @@
package org.whispersystems.textsecuregcm.controllers;
import java.util.List;
public class StaleDevicesException extends Throwable {
private final List<Long> staleDevices;
public StaleDevicesException(List<Long> staleDevices) {
this.staleDevices = staleDevices;
}
public List<Long> getStaleDevices() {
return staleDevices;
}
}

View File

@@ -31,12 +31,16 @@ public class AccountAttributes {
@JsonProperty
private boolean fetchesMessages;
@JsonProperty
private int registrationId;
public AccountAttributes() {}
public AccountAttributes(String signalingKey, boolean supportsSms, boolean fetchesMessages) {
this.signalingKey = signalingKey;
this.supportsSms = supportsSms;
public AccountAttributes(String signalingKey, boolean supportsSms, boolean fetchesMessages, int registrationId) {
this.signalingKey = signalingKey;
this.supportsSms = supportsSms;
this.fetchesMessages = fetchesMessages;
this.registrationId = registrationId;
}
public String getSignalingKey() {
@@ -51,4 +55,7 @@ public class AccountAttributes {
return fetchesMessages;
}
public int getRegistrationId() {
return registrationId;
}
}

View File

@@ -30,6 +30,9 @@ public class IncomingMessage {
@JsonProperty
private long destinationDeviceId = 1;
@JsonProperty
private int destinationRegistrationId;
@JsonProperty
@NotEmpty
private String body;
@@ -40,6 +43,7 @@ public class IncomingMessage {
@JsonProperty
private long timestamp;
public String getDestination() {
return destination;
}
@@ -59,4 +63,8 @@ public class IncomingMessage {
public long getDestinationDeviceId() {
return destinationDeviceId;
}
public int getDestinationRegistrationId() {
return destinationRegistrationId;
}
}

View File

@@ -52,6 +52,9 @@ public class PreKey {
@JsonProperty
private boolean lastResort;
@JsonProperty
private int registrationId;
public PreKey() {}
public PreKey(long id, String number, long deviceId, long keyId,
@@ -125,4 +128,12 @@ public class PreKey {
public long getDeviceId() {
return deviceId;
}
public int getRegistrationId() {
return registrationId;
}
public void setRegistrationId(int registrationId) {
this.registrationId = registrationId;
}
}

View File

@@ -0,0 +1,18 @@
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List;
public class StaleDevices {
@JsonProperty
private List<Long> staleDevices;
public StaleDevices() {}
public StaleDevices(List<Long> staleDevices) {
this.staleDevices = staleDevices;
}
}

View File

@@ -48,11 +48,14 @@ public class Device implements Serializable {
@JsonProperty
private boolean fetchesMessages;
@JsonProperty
private int registrationId;
public Device() {}
public Device(long id, String authToken, String salt,
String signalingKey, String gcmId, String apnId,
boolean fetchesMessages)
boolean fetchesMessages, int registrationId)
{
this.id = id;
this.authToken = authToken;
@@ -61,6 +64,7 @@ public class Device implements Serializable {
this.gcmId = gcmId;
this.apnId = apnId;
this.fetchesMessages = fetchesMessages;
this.registrationId = registrationId;
}
public String getApnId() {
@@ -119,4 +123,12 @@ public class Device implements Serializable {
public boolean isMaster() {
return getId() == MASTER_ID;
}
public int getRegistrationId() {
return registrationId;
}
public void setRegistrationId(int registrationId) {
this.registrationId = registrationId;
}
}

View File

@@ -84,14 +84,14 @@ public abstract class Keys {
}
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
public Optional<PreKey> get(String number, long deviceId) {
public Optional<UnstructuredPreKeyList> get(String number, long deviceId) {
PreKey preKey = retrieveFirst(number, deviceId);
if (preKey != null && !preKey.isLastResort()) {
removeKey(preKey.getId());
}
if (preKey != null) return Optional.of(preKey);
if (preKey != null) return Optional.of(new UnstructuredPreKeyList(preKey));
else return Optional.absent();
}