Handle simultaneous initiate protocol case.

1) Modify SessionRecord to store a list of "previous" sessions
   in addition to the current active session.  Previous sessions
   can be used for receiving messages, but not for sending
   messages.

2) When a possible "simultaneous initiate" is detected, push the
   current session onto the "previous session" stack instead of
   clearing it and starting over.

3) Additionally, mark the new session created on a received
   possible "simultaneous initiate" as stale for sending.  The
   next outgoing message would trigger a full prekey refresh.

4) Work to do: outgoing messages on the SMS transport should
   probably not use the existing session if it's marked stale
   for sending.  These messages need to fail and notify the user,
   similar to how we'll handle SMS fallback to push users before
   a prekey session is created.
This commit is contained in:
Moxie Marlinspike
2014-03-17 17:24:00 -07:00
parent edc20883eb
commit 926d3c929f
12 changed files with 1435 additions and 514 deletions

View File

@@ -1,6 +1,7 @@
package org.whispersystems.textsecure.crypto;
import android.content.Context;
import android.util.Log;
import android.util.Pair;
import org.whispersystems.textsecure.crypto.ecc.Curve;
@@ -14,10 +15,12 @@ import org.whispersystems.textsecure.crypto.ratchet.MessageKeys;
import org.whispersystems.textsecure.crypto.ratchet.RootKey;
import org.whispersystems.textsecure.storage.RecipientDevice;
import org.whispersystems.textsecure.storage.SessionRecordV2;
import org.whispersystems.textsecure.storage.SessionState;
import org.whispersystems.textsecure.util.Conversions;
import java.security.InvalidAlgorithmParameterException;
import java.security.NoSuchAlgorithmException;
import java.util.List;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
@@ -45,27 +48,28 @@ public class SessionCipherV2 extends SessionCipher {
public CiphertextMessage encrypt(byte[] paddedMessage) {
synchronized (SESSION_LOCK) {
SessionRecordV2 sessionRecord = getSessionRecord();
ChainKey chainKey = sessionRecord.getSenderChainKey();
SessionState sessionState = sessionRecord.getSessionState();
ChainKey chainKey = sessionState.getSenderChainKey();
MessageKeys messageKeys = chainKey.getMessageKeys();
ECPublicKey senderEphemeral = sessionRecord.getSenderEphemeral();
int previousCounter = sessionRecord.getPreviousCounter();
ECPublicKey senderEphemeral = sessionState.getSenderEphemeral();
int previousCounter = sessionState.getPreviousCounter();
byte[] ciphertextBody = getCiphertext(messageKeys, paddedMessage);
CiphertextMessage ciphertextMessage = new WhisperMessageV2(messageKeys.getMacKey(),
senderEphemeral, chainKey.getIndex(),
previousCounter, ciphertextBody);
if (sessionRecord.hasPendingPreKey()) {
Pair<Integer, ECPublicKey> pendingPreKey = sessionRecord.getPendingPreKey();
int localRegistrationId = sessionRecord.getLocalRegistrationId();
if (sessionState.hasPendingPreKey()) {
Pair<Integer, ECPublicKey> pendingPreKey = sessionState.getPendingPreKey();
int localRegistrationId = sessionState.getLocalRegistrationId();
ciphertextMessage = new PreKeyWhisperMessage(localRegistrationId, pendingPreKey.first,
pendingPreKey.second,
sessionRecord.getLocalIdentityKey(),
sessionState.getLocalIdentityKey(),
(WhisperMessageV2) ciphertextMessage);
}
sessionRecord.setSenderChainKey(chainKey.getNextChainKey());
sessionState.setSenderChainKey(chainKey.getNextChainKey());
sessionRecord.save();
return ciphertextMessage;
@@ -75,50 +79,83 @@ public class SessionCipherV2 extends SessionCipher {
@Override
public byte[] decrypt(byte[] decodedMessage) throws InvalidMessageException {
synchronized (SESSION_LOCK) {
SessionRecordV2 sessionRecord = getSessionRecord();
WhisperMessageV2 ciphertextMessage = new WhisperMessageV2(decodedMessage);
ECPublicKey theirEphemeral = ciphertextMessage.getSenderEphemeral();
int counter = ciphertextMessage.getCounter();
ChainKey chainKey = getOrCreateChainKey(sessionRecord, theirEphemeral);
MessageKeys messageKeys = getOrCreateMessageKeys(sessionRecord, theirEphemeral,
chainKey, counter);
SessionRecordV2 sessionRecord = getSessionRecord();
SessionState sessionState = sessionRecord.getSessionState();
List<SessionState> previousStates = sessionRecord.getPreviousSessions();
ciphertextMessage.verifyMac(messageKeys.getMacKey());
try {
byte[] plaintext = decrypt(sessionState, decodedMessage);
sessionRecord.save();
byte[] plaintext = getPlaintext(messageKeys, ciphertextMessage.getBody());
return plaintext;
} catch (InvalidMessageException e) {
Log.w("SessionCipherV2", e);
}
sessionRecord.clearPendingPreKey();
sessionRecord.save();
for (SessionState previousState : previousStates) {
try {
byte[] plaintext = decrypt(previousState, decodedMessage);
sessionRecord.save();
return plaintext;
return plaintext;
} catch (InvalidMessageException e) {
Log.w("SessionCipherV2", e);
}
}
throw new InvalidMessageException("No valid sessions.");
}
}
public byte[] decrypt(SessionState sessionState, byte[] decodedMessage)
throws InvalidMessageException
{
if (!sessionState.hasSenderChain()) {
throw new InvalidMessageException("Uninitialized session!");
}
WhisperMessageV2 ciphertextMessage = new WhisperMessageV2(decodedMessage);
ECPublicKey theirEphemeral = ciphertextMessage.getSenderEphemeral();
int counter = ciphertextMessage.getCounter();
ChainKey chainKey = getOrCreateChainKey(sessionState, theirEphemeral);
MessageKeys messageKeys = getOrCreateMessageKeys(sessionState, theirEphemeral,
chainKey, counter);
ciphertextMessage.verifyMac(messageKeys.getMacKey());
byte[] plaintext = getPlaintext(messageKeys, ciphertextMessage.getBody());
sessionState.clearPendingPreKey();
return plaintext;
}
@Override
public int getRemoteRegistrationId() {
synchronized (SESSION_LOCK) {
SessionRecordV2 sessionRecord = getSessionRecord();
return sessionRecord.getRemoteRegistrationId();
return sessionRecord.getSessionState().getRemoteRegistrationId();
}
}
private ChainKey getOrCreateChainKey(SessionRecordV2 sessionRecord, ECPublicKey theirEphemeral)
private ChainKey getOrCreateChainKey(SessionState sessionState, ECPublicKey theirEphemeral)
throws InvalidMessageException
{
try {
if (sessionRecord.hasReceiverChain(theirEphemeral)) {
return sessionRecord.getReceiverChainKey(theirEphemeral);
if (sessionState.hasReceiverChain(theirEphemeral)) {
return sessionState.getReceiverChainKey(theirEphemeral);
} else {
RootKey rootKey = sessionRecord.getRootKey();
ECKeyPair ourEphemeral = sessionRecord.getSenderEphemeralPair();
RootKey rootKey = sessionState.getRootKey();
ECKeyPair ourEphemeral = sessionState.getSenderEphemeralPair();
Pair<RootKey, ChainKey> receiverChain = rootKey.createChain(theirEphemeral, ourEphemeral);
ECKeyPair ourNewEphemeral = Curve.generateKeyPairForType(Curve.DJB_TYPE);
Pair<RootKey, ChainKey> senderChain = receiverChain.first.createChain(theirEphemeral, ourNewEphemeral);
sessionRecord.setRootKey(senderChain.first);
sessionRecord.addReceiverChain(theirEphemeral, receiverChain.second);
sessionRecord.setPreviousCounter(sessionRecord.getSenderChainKey().getIndex()-1);
sessionRecord.setSenderChain(ourNewEphemeral, senderChain.second);
sessionState.setRootKey(senderChain.first);
sessionState.addReceiverChain(theirEphemeral, receiverChain.second);
sessionState.setPreviousCounter(sessionState.getSenderChainKey().getIndex()-1);
sessionState.setSenderChain(ourNewEphemeral, senderChain.second);
return receiverChain.second;
}
@@ -127,14 +164,14 @@ public class SessionCipherV2 extends SessionCipher {
}
}
private MessageKeys getOrCreateMessageKeys(SessionRecordV2 sessionRecord,
private MessageKeys getOrCreateMessageKeys(SessionState sessionState,
ECPublicKey theirEphemeral,
ChainKey chainKey, int counter)
throws InvalidMessageException
{
if (chainKey.getIndex() > counter) {
if (sessionRecord.hasMessageKeys(theirEphemeral, counter)) {
return sessionRecord.removeMessageKeys(theirEphemeral, counter);
if (sessionState.hasMessageKeys(theirEphemeral, counter)) {
return sessionState.removeMessageKeys(theirEphemeral, counter);
} else {
throw new InvalidMessageException("Received message with old counter!");
}
@@ -146,11 +183,11 @@ public class SessionCipherV2 extends SessionCipher {
while (chainKey.getIndex() < counter) {
MessageKeys messageKeys = chainKey.getMessageKeys();
sessionRecord.setMessageKeys(theirEphemeral, messageKeys);
sessionState.setMessageKeys(theirEphemeral, messageKeys);
chainKey = chainKey.getNextChainKey();
}
sessionRecord.setReceiverChainKey(theirEphemeral, chainKey.getNextChainKey());
sessionState.setReceiverChainKey(theirEphemeral, chainKey.getNextChainKey());
return chainKey.getMessageKeys();
}

View File

@@ -10,14 +10,14 @@ import org.whispersystems.textsecure.crypto.ecc.ECKeyPair;
import org.whispersystems.textsecure.crypto.ecc.ECPublicKey;
import org.whispersystems.textsecure.crypto.kdf.DerivedSecrets;
import org.whispersystems.textsecure.crypto.kdf.HKDF;
import org.whispersystems.textsecure.storage.SessionRecordV2;
import org.whispersystems.textsecure.storage.SessionState;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
public class RatchetingSession {
public static void initializeSession(SessionRecordV2 sessionRecord,
public static void initializeSession(SessionState sessionState,
ECKeyPair ourBaseKey,
ECPublicKey theirBaseKey,
ECKeyPair ourEphemeralKey,
@@ -27,48 +27,48 @@ public class RatchetingSession {
throws InvalidKeyException
{
if (isAlice(ourBaseKey.getPublicKey(), theirBaseKey, ourEphemeralKey.getPublicKey(), theirEphemeralKey)) {
initializeSessionAsAlice(sessionRecord, ourBaseKey, theirBaseKey, theirEphemeralKey,
initializeSessionAsAlice(sessionState, ourBaseKey, theirBaseKey, theirEphemeralKey,
ourIdentityKey, theirIdentityKey);
} else {
initializeSessionAsBob(sessionRecord, ourBaseKey, theirBaseKey,
initializeSessionAsBob(sessionState, ourBaseKey, theirBaseKey,
ourEphemeralKey, ourIdentityKey, theirIdentityKey);
}
}
private static void initializeSessionAsAlice(SessionRecordV2 sessionRecord,
private static void initializeSessionAsAlice(SessionState sessionState,
ECKeyPair ourBaseKey, ECPublicKey theirBaseKey,
ECPublicKey theirEphemeralKey,
IdentityKeyPair ourIdentityKey,
IdentityKey theirIdentityKey)
throws InvalidKeyException
{
sessionRecord.setRemoteIdentityKey(theirIdentityKey);
sessionRecord.setLocalIdentityKey(ourIdentityKey.getPublicKey());
sessionState.setRemoteIdentityKey(theirIdentityKey);
sessionState.setLocalIdentityKey(ourIdentityKey.getPublicKey());
ECKeyPair sendingKey = Curve.generateKeyPairForType(ourIdentityKey.getPublicKey().getPublicKey().getType());
Pair<RootKey, ChainKey> receivingChain = calculate3DHE(true, ourBaseKey, theirBaseKey, ourIdentityKey, theirIdentityKey);
Pair<RootKey, ChainKey> sendingChain = receivingChain.first.createChain(theirEphemeralKey, sendingKey);
sessionRecord.addReceiverChain(theirEphemeralKey, receivingChain.second);
sessionRecord.setSenderChain(sendingKey, sendingChain.second);
sessionRecord.setRootKey(sendingChain.first);
sessionState.addReceiverChain(theirEphemeralKey, receivingChain.second);
sessionState.setSenderChain(sendingKey, sendingChain.second);
sessionState.setRootKey(sendingChain.first);
}
private static void initializeSessionAsBob(SessionRecordV2 sessionRecord,
private static void initializeSessionAsBob(SessionState sessionState,
ECKeyPair ourBaseKey, ECPublicKey theirBaseKey,
ECKeyPair ourEphemeralKey,
IdentityKeyPair ourIdentityKey,
IdentityKey theirIdentityKey)
throws InvalidKeyException
{
sessionRecord.setRemoteIdentityKey(theirIdentityKey);
sessionRecord.setLocalIdentityKey(ourIdentityKey.getPublicKey());
sessionState.setRemoteIdentityKey(theirIdentityKey);
sessionState.setLocalIdentityKey(ourIdentityKey.getPublicKey());
Pair<RootKey, ChainKey> sendingChain = calculate3DHE(false, ourBaseKey, theirBaseKey,
ourIdentityKey, theirIdentityKey);
sessionRecord.setSenderChain(ourEphemeralKey, sendingChain.second);
sessionRecord.setRootKey(sendingChain.first);
sessionState.setSenderChain(ourEphemeralKey, sendingChain.second);
sessionState.setRootKey(sendingChain.first);
}
private static Pair<RootKey, ChainKey> calculate3DHE(boolean isAlice,