mirror of
https://github.com/signalapp/Signal-Android.git
synced 2026-04-26 19:56:02 +01:00
Handle simultaneous initiate protocol case.
1) Modify SessionRecord to store a list of "previous" sessions in addition to the current active session. Previous sessions can be used for receiving messages, but not for sending messages. 2) When a possible "simultaneous initiate" is detected, push the current session onto the "previous session" stack instead of clearing it and starting over. 3) Additionally, mark the new session created on a received possible "simultaneous initiate" as stale for sending. The next outgoing message would trigger a full prekey refresh. 4) Work to do: outgoing messages on the SMS transport should probably not use the existing session if it's marked stale for sending. These messages need to fail and notify the user, similar to how we'll handle SMS fallback to push users before a prekey session is created.
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package org.whispersystems.textsecure.crypto;
|
||||
|
||||
import android.content.Context;
|
||||
import android.util.Log;
|
||||
import android.util.Pair;
|
||||
|
||||
import org.whispersystems.textsecure.crypto.ecc.Curve;
|
||||
@@ -14,10 +15,12 @@ import org.whispersystems.textsecure.crypto.ratchet.MessageKeys;
|
||||
import org.whispersystems.textsecure.crypto.ratchet.RootKey;
|
||||
import org.whispersystems.textsecure.storage.RecipientDevice;
|
||||
import org.whispersystems.textsecure.storage.SessionRecordV2;
|
||||
import org.whispersystems.textsecure.storage.SessionState;
|
||||
import org.whispersystems.textsecure.util.Conversions;
|
||||
|
||||
import java.security.InvalidAlgorithmParameterException;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.util.List;
|
||||
|
||||
import javax.crypto.BadPaddingException;
|
||||
import javax.crypto.Cipher;
|
||||
@@ -45,27 +48,28 @@ public class SessionCipherV2 extends SessionCipher {
|
||||
public CiphertextMessage encrypt(byte[] paddedMessage) {
|
||||
synchronized (SESSION_LOCK) {
|
||||
SessionRecordV2 sessionRecord = getSessionRecord();
|
||||
ChainKey chainKey = sessionRecord.getSenderChainKey();
|
||||
SessionState sessionState = sessionRecord.getSessionState();
|
||||
ChainKey chainKey = sessionState.getSenderChainKey();
|
||||
MessageKeys messageKeys = chainKey.getMessageKeys();
|
||||
ECPublicKey senderEphemeral = sessionRecord.getSenderEphemeral();
|
||||
int previousCounter = sessionRecord.getPreviousCounter();
|
||||
ECPublicKey senderEphemeral = sessionState.getSenderEphemeral();
|
||||
int previousCounter = sessionState.getPreviousCounter();
|
||||
|
||||
byte[] ciphertextBody = getCiphertext(messageKeys, paddedMessage);
|
||||
CiphertextMessage ciphertextMessage = new WhisperMessageV2(messageKeys.getMacKey(),
|
||||
senderEphemeral, chainKey.getIndex(),
|
||||
previousCounter, ciphertextBody);
|
||||
|
||||
if (sessionRecord.hasPendingPreKey()) {
|
||||
Pair<Integer, ECPublicKey> pendingPreKey = sessionRecord.getPendingPreKey();
|
||||
int localRegistrationId = sessionRecord.getLocalRegistrationId();
|
||||
if (sessionState.hasPendingPreKey()) {
|
||||
Pair<Integer, ECPublicKey> pendingPreKey = sessionState.getPendingPreKey();
|
||||
int localRegistrationId = sessionState.getLocalRegistrationId();
|
||||
|
||||
ciphertextMessage = new PreKeyWhisperMessage(localRegistrationId, pendingPreKey.first,
|
||||
pendingPreKey.second,
|
||||
sessionRecord.getLocalIdentityKey(),
|
||||
sessionState.getLocalIdentityKey(),
|
||||
(WhisperMessageV2) ciphertextMessage);
|
||||
}
|
||||
|
||||
sessionRecord.setSenderChainKey(chainKey.getNextChainKey());
|
||||
sessionState.setSenderChainKey(chainKey.getNextChainKey());
|
||||
sessionRecord.save();
|
||||
|
||||
return ciphertextMessage;
|
||||
@@ -75,50 +79,83 @@ public class SessionCipherV2 extends SessionCipher {
|
||||
@Override
|
||||
public byte[] decrypt(byte[] decodedMessage) throws InvalidMessageException {
|
||||
synchronized (SESSION_LOCK) {
|
||||
SessionRecordV2 sessionRecord = getSessionRecord();
|
||||
WhisperMessageV2 ciphertextMessage = new WhisperMessageV2(decodedMessage);
|
||||
ECPublicKey theirEphemeral = ciphertextMessage.getSenderEphemeral();
|
||||
int counter = ciphertextMessage.getCounter();
|
||||
ChainKey chainKey = getOrCreateChainKey(sessionRecord, theirEphemeral);
|
||||
MessageKeys messageKeys = getOrCreateMessageKeys(sessionRecord, theirEphemeral,
|
||||
chainKey, counter);
|
||||
SessionRecordV2 sessionRecord = getSessionRecord();
|
||||
SessionState sessionState = sessionRecord.getSessionState();
|
||||
List<SessionState> previousStates = sessionRecord.getPreviousSessions();
|
||||
|
||||
ciphertextMessage.verifyMac(messageKeys.getMacKey());
|
||||
try {
|
||||
byte[] plaintext = decrypt(sessionState, decodedMessage);
|
||||
sessionRecord.save();
|
||||
|
||||
byte[] plaintext = getPlaintext(messageKeys, ciphertextMessage.getBody());
|
||||
return plaintext;
|
||||
} catch (InvalidMessageException e) {
|
||||
Log.w("SessionCipherV2", e);
|
||||
}
|
||||
|
||||
sessionRecord.clearPendingPreKey();
|
||||
sessionRecord.save();
|
||||
for (SessionState previousState : previousStates) {
|
||||
try {
|
||||
byte[] plaintext = decrypt(previousState, decodedMessage);
|
||||
sessionRecord.save();
|
||||
|
||||
return plaintext;
|
||||
return plaintext;
|
||||
} catch (InvalidMessageException e) {
|
||||
Log.w("SessionCipherV2", e);
|
||||
}
|
||||
}
|
||||
|
||||
throw new InvalidMessageException("No valid sessions.");
|
||||
}
|
||||
}
|
||||
|
||||
public byte[] decrypt(SessionState sessionState, byte[] decodedMessage)
|
||||
throws InvalidMessageException
|
||||
{
|
||||
if (!sessionState.hasSenderChain()) {
|
||||
throw new InvalidMessageException("Uninitialized session!");
|
||||
}
|
||||
|
||||
WhisperMessageV2 ciphertextMessage = new WhisperMessageV2(decodedMessage);
|
||||
ECPublicKey theirEphemeral = ciphertextMessage.getSenderEphemeral();
|
||||
int counter = ciphertextMessage.getCounter();
|
||||
ChainKey chainKey = getOrCreateChainKey(sessionState, theirEphemeral);
|
||||
MessageKeys messageKeys = getOrCreateMessageKeys(sessionState, theirEphemeral,
|
||||
chainKey, counter);
|
||||
|
||||
ciphertextMessage.verifyMac(messageKeys.getMacKey());
|
||||
|
||||
byte[] plaintext = getPlaintext(messageKeys, ciphertextMessage.getBody());
|
||||
|
||||
sessionState.clearPendingPreKey();
|
||||
|
||||
return plaintext;
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getRemoteRegistrationId() {
|
||||
synchronized (SESSION_LOCK) {
|
||||
SessionRecordV2 sessionRecord = getSessionRecord();
|
||||
return sessionRecord.getRemoteRegistrationId();
|
||||
return sessionRecord.getSessionState().getRemoteRegistrationId();
|
||||
}
|
||||
}
|
||||
|
||||
private ChainKey getOrCreateChainKey(SessionRecordV2 sessionRecord, ECPublicKey theirEphemeral)
|
||||
private ChainKey getOrCreateChainKey(SessionState sessionState, ECPublicKey theirEphemeral)
|
||||
throws InvalidMessageException
|
||||
{
|
||||
try {
|
||||
if (sessionRecord.hasReceiverChain(theirEphemeral)) {
|
||||
return sessionRecord.getReceiverChainKey(theirEphemeral);
|
||||
if (sessionState.hasReceiverChain(theirEphemeral)) {
|
||||
return sessionState.getReceiverChainKey(theirEphemeral);
|
||||
} else {
|
||||
RootKey rootKey = sessionRecord.getRootKey();
|
||||
ECKeyPair ourEphemeral = sessionRecord.getSenderEphemeralPair();
|
||||
RootKey rootKey = sessionState.getRootKey();
|
||||
ECKeyPair ourEphemeral = sessionState.getSenderEphemeralPair();
|
||||
Pair<RootKey, ChainKey> receiverChain = rootKey.createChain(theirEphemeral, ourEphemeral);
|
||||
ECKeyPair ourNewEphemeral = Curve.generateKeyPairForType(Curve.DJB_TYPE);
|
||||
Pair<RootKey, ChainKey> senderChain = receiverChain.first.createChain(theirEphemeral, ourNewEphemeral);
|
||||
|
||||
sessionRecord.setRootKey(senderChain.first);
|
||||
sessionRecord.addReceiverChain(theirEphemeral, receiverChain.second);
|
||||
sessionRecord.setPreviousCounter(sessionRecord.getSenderChainKey().getIndex()-1);
|
||||
sessionRecord.setSenderChain(ourNewEphemeral, senderChain.second);
|
||||
sessionState.setRootKey(senderChain.first);
|
||||
sessionState.addReceiverChain(theirEphemeral, receiverChain.second);
|
||||
sessionState.setPreviousCounter(sessionState.getSenderChainKey().getIndex()-1);
|
||||
sessionState.setSenderChain(ourNewEphemeral, senderChain.second);
|
||||
|
||||
return receiverChain.second;
|
||||
}
|
||||
@@ -127,14 +164,14 @@ public class SessionCipherV2 extends SessionCipher {
|
||||
}
|
||||
}
|
||||
|
||||
private MessageKeys getOrCreateMessageKeys(SessionRecordV2 sessionRecord,
|
||||
private MessageKeys getOrCreateMessageKeys(SessionState sessionState,
|
||||
ECPublicKey theirEphemeral,
|
||||
ChainKey chainKey, int counter)
|
||||
throws InvalidMessageException
|
||||
{
|
||||
if (chainKey.getIndex() > counter) {
|
||||
if (sessionRecord.hasMessageKeys(theirEphemeral, counter)) {
|
||||
return sessionRecord.removeMessageKeys(theirEphemeral, counter);
|
||||
if (sessionState.hasMessageKeys(theirEphemeral, counter)) {
|
||||
return sessionState.removeMessageKeys(theirEphemeral, counter);
|
||||
} else {
|
||||
throw new InvalidMessageException("Received message with old counter!");
|
||||
}
|
||||
@@ -146,11 +183,11 @@ public class SessionCipherV2 extends SessionCipher {
|
||||
|
||||
while (chainKey.getIndex() < counter) {
|
||||
MessageKeys messageKeys = chainKey.getMessageKeys();
|
||||
sessionRecord.setMessageKeys(theirEphemeral, messageKeys);
|
||||
sessionState.setMessageKeys(theirEphemeral, messageKeys);
|
||||
chainKey = chainKey.getNextChainKey();
|
||||
}
|
||||
|
||||
sessionRecord.setReceiverChainKey(theirEphemeral, chainKey.getNextChainKey());
|
||||
sessionState.setReceiverChainKey(theirEphemeral, chainKey.getNextChainKey());
|
||||
return chainKey.getMessageKeys();
|
||||
}
|
||||
|
||||
|
||||
@@ -10,14 +10,14 @@ import org.whispersystems.textsecure.crypto.ecc.ECKeyPair;
|
||||
import org.whispersystems.textsecure.crypto.ecc.ECPublicKey;
|
||||
import org.whispersystems.textsecure.crypto.kdf.DerivedSecrets;
|
||||
import org.whispersystems.textsecure.crypto.kdf.HKDF;
|
||||
import org.whispersystems.textsecure.storage.SessionRecordV2;
|
||||
import org.whispersystems.textsecure.storage.SessionState;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.IOException;
|
||||
|
||||
public class RatchetingSession {
|
||||
|
||||
public static void initializeSession(SessionRecordV2 sessionRecord,
|
||||
public static void initializeSession(SessionState sessionState,
|
||||
ECKeyPair ourBaseKey,
|
||||
ECPublicKey theirBaseKey,
|
||||
ECKeyPair ourEphemeralKey,
|
||||
@@ -27,48 +27,48 @@ public class RatchetingSession {
|
||||
throws InvalidKeyException
|
||||
{
|
||||
if (isAlice(ourBaseKey.getPublicKey(), theirBaseKey, ourEphemeralKey.getPublicKey(), theirEphemeralKey)) {
|
||||
initializeSessionAsAlice(sessionRecord, ourBaseKey, theirBaseKey, theirEphemeralKey,
|
||||
initializeSessionAsAlice(sessionState, ourBaseKey, theirBaseKey, theirEphemeralKey,
|
||||
ourIdentityKey, theirIdentityKey);
|
||||
} else {
|
||||
initializeSessionAsBob(sessionRecord, ourBaseKey, theirBaseKey,
|
||||
initializeSessionAsBob(sessionState, ourBaseKey, theirBaseKey,
|
||||
ourEphemeralKey, ourIdentityKey, theirIdentityKey);
|
||||
}
|
||||
}
|
||||
|
||||
private static void initializeSessionAsAlice(SessionRecordV2 sessionRecord,
|
||||
private static void initializeSessionAsAlice(SessionState sessionState,
|
||||
ECKeyPair ourBaseKey, ECPublicKey theirBaseKey,
|
||||
ECPublicKey theirEphemeralKey,
|
||||
IdentityKeyPair ourIdentityKey,
|
||||
IdentityKey theirIdentityKey)
|
||||
throws InvalidKeyException
|
||||
{
|
||||
sessionRecord.setRemoteIdentityKey(theirIdentityKey);
|
||||
sessionRecord.setLocalIdentityKey(ourIdentityKey.getPublicKey());
|
||||
sessionState.setRemoteIdentityKey(theirIdentityKey);
|
||||
sessionState.setLocalIdentityKey(ourIdentityKey.getPublicKey());
|
||||
|
||||
ECKeyPair sendingKey = Curve.generateKeyPairForType(ourIdentityKey.getPublicKey().getPublicKey().getType());
|
||||
Pair<RootKey, ChainKey> receivingChain = calculate3DHE(true, ourBaseKey, theirBaseKey, ourIdentityKey, theirIdentityKey);
|
||||
Pair<RootKey, ChainKey> sendingChain = receivingChain.first.createChain(theirEphemeralKey, sendingKey);
|
||||
|
||||
sessionRecord.addReceiverChain(theirEphemeralKey, receivingChain.second);
|
||||
sessionRecord.setSenderChain(sendingKey, sendingChain.second);
|
||||
sessionRecord.setRootKey(sendingChain.first);
|
||||
sessionState.addReceiverChain(theirEphemeralKey, receivingChain.second);
|
||||
sessionState.setSenderChain(sendingKey, sendingChain.second);
|
||||
sessionState.setRootKey(sendingChain.first);
|
||||
}
|
||||
|
||||
private static void initializeSessionAsBob(SessionRecordV2 sessionRecord,
|
||||
private static void initializeSessionAsBob(SessionState sessionState,
|
||||
ECKeyPair ourBaseKey, ECPublicKey theirBaseKey,
|
||||
ECKeyPair ourEphemeralKey,
|
||||
IdentityKeyPair ourIdentityKey,
|
||||
IdentityKey theirIdentityKey)
|
||||
throws InvalidKeyException
|
||||
{
|
||||
sessionRecord.setRemoteIdentityKey(theirIdentityKey);
|
||||
sessionRecord.setLocalIdentityKey(ourIdentityKey.getPublicKey());
|
||||
sessionState.setRemoteIdentityKey(theirIdentityKey);
|
||||
sessionState.setLocalIdentityKey(ourIdentityKey.getPublicKey());
|
||||
|
||||
Pair<RootKey, ChainKey> sendingChain = calculate3DHE(false, ourBaseKey, theirBaseKey,
|
||||
ourIdentityKey, theirIdentityKey);
|
||||
|
||||
sessionRecord.setSenderChain(ourEphemeralKey, sendingChain.second);
|
||||
sessionRecord.setRootKey(sendingChain.first);
|
||||
sessionState.setSenderChain(ourEphemeralKey, sendingChain.second);
|
||||
sessionState.setRootKey(sendingChain.first);
|
||||
}
|
||||
|
||||
private static Pair<RootKey, ChainKey> calculate3DHE(boolean isAlice,
|
||||
|
||||
Reference in New Issue
Block a user