Collapse SessionRecord, SessionState, and PreKeyRecord interfaces.

This commit is contained in:
Moxie Marlinspike
2014-04-24 15:39:55 -07:00
parent 5a3c19fe3e
commit a601c56af1
25 changed files with 1271 additions and 1836 deletions

View File

@@ -28,7 +28,6 @@ import org.whispersystems.libaxolotl.InvalidKeyIdException;
import org.whispersystems.libaxolotl.state.PreKeyRecord;
import org.whispersystems.libaxolotl.state.PreKeyStore;
import org.whispersystems.libaxolotl.util.Medium;
import org.whispersystems.textsecure.storage.TextSecurePreKeyRecord;
import org.whispersystems.textsecure.storage.TextSecurePreKeyStore;
import org.whispersystems.textsecure.util.Util;
@@ -53,7 +52,7 @@ public class PreKeyUtil {
for (int i=0;i<BATCH_SIZE;i++) {
int preKeyId = (preKeyIdOffset + i) % Medium.MAX_VALUE;
ECKeyPair keyPair = Curve25519.generateKeyPair(true);
PreKeyRecord record = new TextSecurePreKeyRecord(masterSecret, preKeyId, keyPair);
PreKeyRecord record = new PreKeyRecord(preKeyId, keyPair);
preKeyStore.store(preKeyId, record);
records.add(record);
@@ -76,7 +75,7 @@ public class PreKeyUtil {
}
ECKeyPair keyPair = Curve25519.generateKeyPair(true);
PreKeyRecord record = new TextSecurePreKeyRecord(masterSecret, Medium.MAX_VALUE, keyPair);
PreKeyRecord record = new PreKeyRecord(Medium.MAX_VALUE, keyPair);
preKeyStore.store(Medium.MAX_VALUE, record);

View File

@@ -1,117 +0,0 @@
package org.whispersystems.textsecure.storage;
import android.util.Log;
import com.google.protobuf.ByteString;
import org.whispersystems.libaxolotl.InvalidKeyException;
import org.whispersystems.libaxolotl.InvalidMessageException;
import org.whispersystems.libaxolotl.ecc.Curve;
import org.whispersystems.libaxolotl.ecc.ECKeyPair;
import org.whispersystems.libaxolotl.ecc.ECPrivateKey;
import org.whispersystems.libaxolotl.ecc.ECPublicKey;
import org.whispersystems.libaxolotl.state.PreKeyRecord;
import org.whispersystems.textsecure.crypto.MasterCipher;
import org.whispersystems.textsecure.crypto.MasterSecret;
import org.whispersystems.textsecure.util.Conversions;
import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
public class TextSecurePreKeyRecord implements PreKeyRecord {
private static final int CURRENT_VERSION_MARKER = 1;
private final MasterSecret masterSecret;
private StorageProtos.PreKeyRecordStructure structure;
public TextSecurePreKeyRecord(MasterSecret masterSecret, int id, ECKeyPair keyPair) {
this.masterSecret = masterSecret;
this.structure = StorageProtos.PreKeyRecordStructure.newBuilder()
.setId(id)
.setPublicKey(ByteString.copyFrom(keyPair.getPublicKey()
.serialize()))
.setPrivateKey(ByteString.copyFrom(keyPair.getPrivateKey()
.serialize()))
.build();
}
public TextSecurePreKeyRecord(MasterSecret masterSecret, FileInputStream in)
throws IOException, InvalidMessageException
{
this.masterSecret = masterSecret;
MasterCipher masterCipher = new MasterCipher(masterSecret);
int recordVersion = readInteger(in);
if (recordVersion != CURRENT_VERSION_MARKER) {
Log.w("PreKeyRecord", "Invalid version: " + recordVersion);
return;
}
this.structure =
StorageProtos.PreKeyRecordStructure.parseFrom(masterCipher.decryptBytes(readBlob(in)));
in.close();
}
@Override
public int getId() {
return this.structure.getId();
}
@Override
public ECKeyPair getKeyPair() {
try {
ECPublicKey publicKey = Curve.decodePoint(this.structure.getPublicKey().toByteArray(), 0);
ECPrivateKey privateKey = Curve.decodePrivatePoint(this.structure.getPrivateKey().toByteArray());
return new ECKeyPair(publicKey, privateKey);
} catch (InvalidKeyException e) {
throw new AssertionError(e);
}
}
public byte[] serialize() {
try {
ByteArrayOutputStream out = new ByteArrayOutputStream();
MasterCipher masterCipher = new MasterCipher(masterSecret);
writeInteger(CURRENT_VERSION_MARKER, out);
writeBlob(masterCipher.encryptBytes(structure.toByteArray()), out);
return out.toByteArray();
} catch (IOException e) {
throw new AssertionError(e);
}
}
private byte[] readBlob(FileInputStream in) throws IOException {
int length = readInteger(in);
byte[] blobBytes = new byte[length];
in.read(blobBytes, 0, blobBytes.length);
return blobBytes;
}
private void writeBlob(byte[] blobBytes, OutputStream out) throws IOException {
writeInteger(blobBytes.length, out);
out.write(blobBytes);
}
private int readInteger(FileInputStream in) throws IOException {
byte[] integer = new byte[4];
in.read(integer, 0, integer.length);
return Conversions.byteArrayToInt(integer);
}
private void writeInteger(int value, OutputStream out) throws IOException {
byte[] valueBytes = Conversions.intToByteArray(value);
out.write(valueBytes);
}
}

View File

@@ -7,7 +7,9 @@ import org.whispersystems.libaxolotl.InvalidKeyIdException;
import org.whispersystems.libaxolotl.InvalidMessageException;
import org.whispersystems.libaxolotl.state.PreKeyRecord;
import org.whispersystems.libaxolotl.state.PreKeyStore;
import org.whispersystems.textsecure.crypto.MasterCipher;
import org.whispersystems.textsecure.crypto.MasterSecret;
import org.whispersystems.textsecure.util.Conversions;
import java.io.File;
import java.io.FileInputStream;
@@ -18,8 +20,9 @@ import java.nio.channels.FileChannel;
public class TextSecurePreKeyStore implements PreKeyStore {
public static final String PREKEY_DIRECTORY = "prekeys";
private static final String TAG = TextSecurePreKeyStore.class.getSimpleName();
public static final String PREKEY_DIRECTORY = "prekeys";
private static final int CURRENT_VERSION_MARKER = 1;
private static final String TAG = TextSecurePreKeyStore.class.getSimpleName();
private final Context context;
private final MasterSecret masterSecret;
@@ -32,8 +35,17 @@ public class TextSecurePreKeyStore implements PreKeyStore {
@Override
public PreKeyRecord load(int preKeyId) throws InvalidKeyIdException {
try {
FileInputStream fin = new FileInputStream(getPreKeyFile(preKeyId));
return new TextSecurePreKeyRecord(masterSecret, fin);
MasterCipher masterCipher = new MasterCipher(masterSecret);
FileInputStream fin = new FileInputStream(getPreKeyFile(preKeyId));
int recordVersion = readInteger(fin);
if (recordVersion != CURRENT_VERSION_MARKER) {
throw new AssertionError("Invalid version: " + recordVersion);
}
byte[] serializedRecord = masterCipher.decryptBytes(readBlob(fin));
return new PreKeyRecord(serializedRecord);
} catch (IOException | InvalidMessageException e) {
Log.w(TAG, e);
throw new InvalidKeyIdException(e);
@@ -43,11 +55,13 @@ public class TextSecurePreKeyStore implements PreKeyStore {
@Override
public void store(int preKeyId, PreKeyRecord record) {
try {
RandomAccessFile recordFile = new RandomAccessFile(getPreKeyFile(preKeyId), "rw");
FileChannel out = recordFile.getChannel();
MasterCipher masterCipher = new MasterCipher(masterSecret);
RandomAccessFile recordFile = new RandomAccessFile(getPreKeyFile(preKeyId), "rw");
FileChannel out = recordFile.getChannel();
out.position(0);
out.write(ByteBuffer.wrap(record.serialize()));
writeInteger(CURRENT_VERSION_MARKER, out);
writeBlob(masterCipher.encryptBytes(record.serialize()), out);
out.truncate(out.position());
recordFile.close();
@@ -83,4 +97,29 @@ public class TextSecurePreKeyStore implements PreKeyStore {
return directory;
}
private byte[] readBlob(FileInputStream in) throws IOException {
int length = readInteger(in);
byte[] blobBytes = new byte[length];
in.read(blobBytes, 0, blobBytes.length);
return blobBytes;
}
private void writeBlob(byte[] blobBytes, FileChannel out) throws IOException {
writeInteger(blobBytes.length, out);
out.write(ByteBuffer.wrap(blobBytes));
}
private int readInteger(FileInputStream in) throws IOException {
byte[] integer = new byte[4];
in.read(integer, 0, integer.length);
return Conversions.byteArrayToInt(integer);
}
private void writeInteger(int value, FileChannel out) throws IOException {
byte[] valueBytes = Conversions.intToByteArray(value);
out.write(ByteBuffer.wrap(valueBytes));
}
}

View File

@@ -1,142 +0,0 @@
package org.whispersystems.textsecure.storage;
import org.whispersystems.libaxolotl.InvalidMessageException;
import org.whispersystems.libaxolotl.state.SessionRecord;
import org.whispersystems.libaxolotl.state.SessionState;
import org.whispersystems.textsecure.crypto.MasterCipher;
import org.whispersystems.textsecure.crypto.MasterSecret;
import org.whispersystems.textsecure.util.Conversions;
import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.LinkedList;
import java.util.List;
import static org.whispersystems.textsecure.storage.StorageProtos.RecordStructure;
import static org.whispersystems.textsecure.storage.StorageProtos.SessionStructure;
public class TextSecureSessionRecord implements SessionRecord {
private static final int SINGLE_STATE_VERSION = 1;
private static final int ARCHIVE_STATES_VERSION = 2;
private static final int CURRENT_VERSION = 2;
private TextSecureSessionState sessionState = new TextSecureSessionState(SessionStructure.newBuilder().build());
private List<SessionState> previousStates = new LinkedList<>();
private final MasterSecret masterSecret;
public TextSecureSessionRecord(MasterSecret masterSecret) {
this.masterSecret = masterSecret;
}
public TextSecureSessionRecord(MasterSecret masterSecret, FileInputStream in)
throws IOException, InvalidMessageException
{
this.masterSecret = masterSecret;
int versionMarker = readInteger(in);
if (versionMarker > CURRENT_VERSION) {
throw new AssertionError("Unknown version: " + versionMarker);
}
MasterCipher cipher = new MasterCipher(masterSecret);
byte[] encryptedBlob = readBlob(in);
if (versionMarker == SINGLE_STATE_VERSION) {
byte[] plaintextBytes = cipher.decryptBytes(encryptedBlob);
SessionStructure sessionStructure = SessionStructure.parseFrom(plaintextBytes);
this.sessionState = new TextSecureSessionState(sessionStructure);
} else if (versionMarker == ARCHIVE_STATES_VERSION) {
byte[] plaintextBytes = cipher.decryptBytes(encryptedBlob);
RecordStructure recordStructure = RecordStructure.parseFrom(plaintextBytes);
this.sessionState = new TextSecureSessionState(recordStructure.getCurrentSession());
this.previousStates = new LinkedList<>();
for (SessionStructure sessionStructure : recordStructure.getPreviousSessionsList()) {
this.previousStates.add(new TextSecureSessionState(sessionStructure));
}
} else {
throw new AssertionError("Unknown version: " + versionMarker);
}
in.close();
}
@Override
public SessionState getSessionState() {
return sessionState;
}
@Override
public List<SessionState> getPreviousSessionStates() {
return previousStates;
}
@Override
public void reset() {
this.sessionState = new TextSecureSessionState(SessionStructure.newBuilder().build());
this.previousStates = new LinkedList<>();
}
@Override
public void archiveCurrentState() {
this.previousStates.add(sessionState);
this.sessionState = new TextSecureSessionState(SessionStructure.newBuilder().build());
}
@Override
public byte[] serialize() {
try {
List<SessionStructure> previousStructures = new LinkedList<>();
for (SessionState previousState : previousStates) {
previousStructures.add(((TextSecureSessionState)previousState).getStructure());
}
RecordStructure record = RecordStructure.newBuilder()
.setCurrentSession(sessionState.getStructure())
.addAllPreviousSessions(previousStructures)
.build();
ByteArrayOutputStream serialized = new ByteArrayOutputStream();
MasterCipher cipher = new MasterCipher(masterSecret);
writeInteger(CURRENT_VERSION, serialized);
writeBlob(cipher.encryptBytes(record.toByteArray()), serialized);
return serialized.toByteArray();
} catch (IOException e) {
throw new AssertionError(e);
}
}
private byte[] readBlob(FileInputStream in) throws IOException {
int length = readInteger(in);
byte[] blobBytes = new byte[length];
in.read(blobBytes, 0, blobBytes.length);
return blobBytes;
}
private void writeBlob(byte[] blobBytes, OutputStream out) throws IOException {
writeInteger(blobBytes.length, out);
out.write(blobBytes);
}
private int readInteger(FileInputStream in) throws IOException {
byte[] integer = new byte[4];
in.read(integer, 0, integer.length);
return Conversions.byteArrayToInt(integer);
}
private void writeInteger(int value, OutputStream out) throws IOException {
byte[] valueBytes = Conversions.intToByteArray(value);
out.write(valueBytes);
}
}

View File

@@ -1,449 +0,0 @@
/**
* Copyright (C) 2014 Open Whisper Systems
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.whispersystems.textsecure.storage;
import android.util.Log;
import com.google.protobuf.ByteString;
import org.whispersystems.libaxolotl.IdentityKey;
import org.whispersystems.libaxolotl.IdentityKeyPair;
import org.whispersystems.libaxolotl.InvalidKeyException;
import org.whispersystems.libaxolotl.state.SessionState;
import org.whispersystems.libaxolotl.ecc.Curve;
import org.whispersystems.libaxolotl.ecc.ECKeyPair;
import org.whispersystems.libaxolotl.ecc.ECPrivateKey;
import org.whispersystems.libaxolotl.ecc.ECPublicKey;
import org.whispersystems.libaxolotl.ratchet.ChainKey;
import org.whispersystems.libaxolotl.ratchet.MessageKeys;
import org.whispersystems.libaxolotl.ratchet.RootKey;
import org.whispersystems.libaxolotl.util.Pair;
import org.whispersystems.textsecure.storage.StorageProtos.SessionStructure.Chain;
import org.whispersystems.textsecure.storage.StorageProtos.SessionStructure.PendingKeyExchange;
import org.whispersystems.textsecure.storage.StorageProtos.SessionStructure.PendingPreKey;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import javax.crypto.spec.SecretKeySpec;
import static org.whispersystems.textsecure.storage.StorageProtos.SessionStructure;
public class TextSecureSessionState implements SessionState {
private SessionStructure sessionStructure;
public TextSecureSessionState(SessionStructure sessionStructure) {
this.sessionStructure = sessionStructure;
}
public SessionStructure getStructure() {
return sessionStructure;
}
public void setNeedsRefresh(boolean needsRefresh) {
this.sessionStructure = this.sessionStructure.toBuilder()
.setNeedsRefresh(needsRefresh)
.build();
}
public boolean getNeedsRefresh() {
return this.sessionStructure.getNeedsRefresh();
}
public void setSessionVersion(int version) {
this.sessionStructure = this.sessionStructure.toBuilder()
.setSessionVersion(version)
.build();
}
public int getSessionVersion() {
return this.sessionStructure.getSessionVersion();
}
public void setRemoteIdentityKey(IdentityKey identityKey) {
this.sessionStructure = this.sessionStructure.toBuilder()
.setRemoteIdentityPublic(ByteString.copyFrom(identityKey.serialize()))
.build();
}
public void setLocalIdentityKey(IdentityKey identityKey) {
this.sessionStructure = this.sessionStructure.toBuilder()
.setLocalIdentityPublic(ByteString.copyFrom(identityKey.serialize()))
.build();
}
public IdentityKey getRemoteIdentityKey() {
try {
if (!this.sessionStructure.hasRemoteIdentityPublic()) {
return null;
}
return new IdentityKey(this.sessionStructure.getRemoteIdentityPublic().toByteArray(), 0);
} catch (InvalidKeyException e) {
Log.w("SessionRecordV2", e);
return null;
}
}
public IdentityKey getLocalIdentityKey() {
try {
return new IdentityKey(this.sessionStructure.getLocalIdentityPublic().toByteArray(), 0);
} catch (InvalidKeyException e) {
throw new AssertionError(e);
}
}
public int getPreviousCounter() {
return sessionStructure.getPreviousCounter();
}
public void setPreviousCounter(int previousCounter) {
this.sessionStructure = this.sessionStructure.toBuilder()
.setPreviousCounter(previousCounter)
.build();
}
public RootKey getRootKey() {
return new RootKey(this.sessionStructure.getRootKey().toByteArray());
}
public void setRootKey(RootKey rootKey) {
this.sessionStructure = this.sessionStructure.toBuilder()
.setRootKey(ByteString.copyFrom(rootKey.getKeyBytes()))
.build();
}
public ECPublicKey getSenderEphemeral() {
try {
return Curve.decodePoint(sessionStructure.getSenderChain().getSenderEphemeral().toByteArray(), 0);
} catch (InvalidKeyException e) {
throw new AssertionError(e);
}
}
public ECKeyPair getSenderEphemeralPair() {
ECPublicKey publicKey = getSenderEphemeral();
ECPrivateKey privateKey = Curve.decodePrivatePoint(sessionStructure.getSenderChain()
.getSenderEphemeralPrivate()
.toByteArray());
return new ECKeyPair(publicKey, privateKey);
}
public boolean hasReceiverChain(ECPublicKey senderEphemeral) {
return getReceiverChain(senderEphemeral) != null;
}
public boolean hasSenderChain() {
return sessionStructure.hasSenderChain();
}
private Pair<Chain,Integer> getReceiverChain(ECPublicKey senderEphemeral) {
List<Chain> receiverChains = sessionStructure.getReceiverChainsList();
int index = 0;
for (Chain receiverChain : receiverChains) {
try {
ECPublicKey chainSenderEphemeral = Curve.decodePoint(receiverChain.getSenderEphemeral().toByteArray(), 0);
if (chainSenderEphemeral.equals(senderEphemeral)) {
return new Pair<Chain,Integer>(receiverChain,index);
}
} catch (InvalidKeyException e) {
Log.w("SessionRecordV2", e);
}
index++;
}
return null;
}
public ChainKey getReceiverChainKey(ECPublicKey senderEphemeral) {
Pair<Chain,Integer> receiverChainAndIndex = getReceiverChain(senderEphemeral);
Chain receiverChain = receiverChainAndIndex.first();
if (receiverChain == null) {
return null;
} else {
return new ChainKey(receiverChain.getChainKey().getKey().toByteArray(),
receiverChain.getChainKey().getIndex());
}
}
public void addReceiverChain(ECPublicKey senderEphemeral, ChainKey chainKey) {
Chain.ChainKey chainKeyStructure = Chain.ChainKey.newBuilder()
.setKey(ByteString.copyFrom(chainKey.getKey()))
.setIndex(chainKey.getIndex())
.build();
Chain chain = Chain.newBuilder()
.setChainKey(chainKeyStructure)
.setSenderEphemeral(ByteString.copyFrom(senderEphemeral.serialize()))
.build();
this.sessionStructure = this.sessionStructure.toBuilder().addReceiverChains(chain).build();
if (this.sessionStructure.getReceiverChainsList().size() > 5) {
this.sessionStructure = this.sessionStructure.toBuilder()
.removeReceiverChains(0)
.build();
}
}
public void setSenderChain(ECKeyPair senderEphemeralPair, ChainKey chainKey) {
Chain.ChainKey chainKeyStructure = Chain.ChainKey.newBuilder()
.setKey(ByteString.copyFrom(chainKey.getKey()))
.setIndex(chainKey.getIndex())
.build();
Chain senderChain = Chain.newBuilder()
.setSenderEphemeral(ByteString.copyFrom(senderEphemeralPair.getPublicKey().serialize()))
.setSenderEphemeralPrivate(ByteString.copyFrom(senderEphemeralPair.getPrivateKey().serialize()))
.setChainKey(chainKeyStructure)
.build();
this.sessionStructure = this.sessionStructure.toBuilder().setSenderChain(senderChain).build();
}
public ChainKey getSenderChainKey() {
Chain.ChainKey chainKeyStructure = sessionStructure.getSenderChain().getChainKey();
return new ChainKey(chainKeyStructure.getKey().toByteArray(), chainKeyStructure.getIndex());
}
public void setSenderChainKey(ChainKey nextChainKey) {
Chain.ChainKey chainKey = Chain.ChainKey.newBuilder()
.setKey(ByteString.copyFrom(nextChainKey.getKey()))
.setIndex(nextChainKey.getIndex())
.build();
Chain chain = sessionStructure.getSenderChain().toBuilder()
.setChainKey(chainKey).build();
this.sessionStructure = this.sessionStructure.toBuilder().setSenderChain(chain).build();
}
public boolean hasMessageKeys(ECPublicKey senderEphemeral, int counter) {
Pair<Chain,Integer> chainAndIndex = getReceiverChain(senderEphemeral);
Chain chain = chainAndIndex.first();
if (chain == null) {
return false;
}
List<Chain.MessageKey> messageKeyList = chain.getMessageKeysList();
for (Chain.MessageKey messageKey : messageKeyList) {
if (messageKey.getIndex() == counter) {
return true;
}
}
return false;
}
public MessageKeys removeMessageKeys(ECPublicKey senderEphemeral, int counter) {
Pair<Chain,Integer> chainAndIndex = getReceiverChain(senderEphemeral);
Chain chain = chainAndIndex.first();
if (chain == null) {
return null;
}
List<Chain.MessageKey> messageKeyList = new LinkedList<Chain.MessageKey>(chain.getMessageKeysList());
Iterator<Chain.MessageKey> messageKeyIterator = messageKeyList.iterator();
MessageKeys result = null;
while (messageKeyIterator.hasNext()) {
Chain.MessageKey messageKey = messageKeyIterator.next();
if (messageKey.getIndex() == counter) {
result = new MessageKeys(new SecretKeySpec(messageKey.getCipherKey().toByteArray(), "AES"),
new SecretKeySpec(messageKey.getMacKey().toByteArray(), "HmacSHA256"),
messageKey.getIndex());
messageKeyIterator.remove();
break;
}
}
Chain updatedChain = chain.toBuilder().clearMessageKeys()
.addAllMessageKeys(messageKeyList)
.build();
this.sessionStructure = this.sessionStructure.toBuilder()
.setReceiverChains(chainAndIndex.second(), updatedChain)
.build();
return result;
}
public void setMessageKeys(ECPublicKey senderEphemeral, MessageKeys messageKeys) {
Pair<Chain,Integer> chainAndIndex = getReceiverChain(senderEphemeral);
Chain chain = chainAndIndex.first();
Chain.MessageKey messageKeyStructure = Chain.MessageKey.newBuilder()
.setCipherKey(ByteString.copyFrom(messageKeys.getCipherKey().getEncoded()))
.setMacKey(ByteString.copyFrom(messageKeys.getMacKey().getEncoded()))
.setIndex(messageKeys.getCounter())
.build();
Chain updatedChain = chain.toBuilder()
.addMessageKeys(messageKeyStructure)
.build();
this.sessionStructure = this.sessionStructure.toBuilder()
.setReceiverChains(chainAndIndex.second(), updatedChain)
.build();
}
public void setReceiverChainKey(ECPublicKey senderEphemeral, ChainKey chainKey) {
Pair<Chain,Integer> chainAndIndex = getReceiverChain(senderEphemeral);
Chain chain = chainAndIndex.first();
Chain.ChainKey chainKeyStructure = Chain.ChainKey.newBuilder()
.setKey(ByteString.copyFrom(chainKey.getKey()))
.setIndex(chainKey.getIndex())
.build();
Chain updatedChain = chain.toBuilder().setChainKey(chainKeyStructure).build();
this.sessionStructure = this.sessionStructure.toBuilder()
.setReceiverChains(chainAndIndex.second(), updatedChain)
.build();
}
public void setPendingKeyExchange(int sequence,
ECKeyPair ourBaseKey,
ECKeyPair ourEphemeralKey,
IdentityKeyPair ourIdentityKey)
{
PendingKeyExchange structure =
PendingKeyExchange.newBuilder()
.setSequence(sequence)
.setLocalBaseKey(ByteString.copyFrom(ourBaseKey.getPublicKey().serialize()))
.setLocalBaseKeyPrivate(ByteString.copyFrom(ourBaseKey.getPrivateKey().serialize()))
.setLocalEphemeralKey(ByteString.copyFrom(ourEphemeralKey.getPublicKey().serialize()))
.setLocalEphemeralKeyPrivate(ByteString.copyFrom(ourEphemeralKey.getPrivateKey().serialize()))
.setLocalIdentityKey(ByteString.copyFrom(ourIdentityKey.getPublicKey().serialize()))
.setLocalIdentityKeyPrivate(ByteString.copyFrom(ourIdentityKey.getPrivateKey().serialize()))
.build();
this.sessionStructure = this.sessionStructure.toBuilder()
.setPendingKeyExchange(structure)
.build();
}
public int getPendingKeyExchangeSequence() {
return sessionStructure.getPendingKeyExchange().getSequence();
}
public ECKeyPair getPendingKeyExchangeBaseKey() throws InvalidKeyException {
ECPublicKey publicKey = Curve.decodePoint(sessionStructure.getPendingKeyExchange()
.getLocalBaseKey().toByteArray(), 0);
ECPrivateKey privateKey = Curve.decodePrivatePoint(sessionStructure.getPendingKeyExchange()
.getLocalBaseKeyPrivate()
.toByteArray());
return new ECKeyPair(publicKey, privateKey);
}
public ECKeyPair getPendingKeyExchangeEphemeralKey() throws InvalidKeyException {
ECPublicKey publicKey = Curve.decodePoint(sessionStructure.getPendingKeyExchange()
.getLocalEphemeralKey().toByteArray(), 0);
ECPrivateKey privateKey = Curve.decodePrivatePoint(sessionStructure.getPendingKeyExchange()
.getLocalEphemeralKeyPrivate()
.toByteArray());
return new ECKeyPair(publicKey, privateKey);
}
public IdentityKeyPair getPendingKeyExchangeIdentityKey() throws InvalidKeyException {
IdentityKey publicKey = new IdentityKey(sessionStructure.getPendingKeyExchange()
.getLocalIdentityKey().toByteArray(), 0);
ECPrivateKey privateKey = Curve.decodePrivatePoint(sessionStructure.getPendingKeyExchange()
.getLocalIdentityKeyPrivate()
.toByteArray());
return new IdentityKeyPair(publicKey, privateKey);
}
public boolean hasPendingKeyExchange() {
return sessionStructure.hasPendingKeyExchange();
}
public void setPendingPreKey(int preKeyId, ECPublicKey baseKey) {
PendingPreKey pending = PendingPreKey.newBuilder()
.setPreKeyId(preKeyId)
.setBaseKey(ByteString.copyFrom(baseKey.serialize()))
.build();
this.sessionStructure = this.sessionStructure.toBuilder()
.setPendingPreKey(pending)
.build();
}
public boolean hasPendingPreKey() {
return this.sessionStructure.hasPendingPreKey();
}
public Pair<Integer, ECPublicKey> getPendingPreKey() {
try {
return new Pair<Integer, ECPublicKey>(sessionStructure.getPendingPreKey().getPreKeyId(),
Curve.decodePoint(sessionStructure.getPendingPreKey()
.getBaseKey()
.toByteArray(), 0));
} catch (InvalidKeyException e) {
throw new AssertionError(e);
}
}
public void clearPendingPreKey() {
this.sessionStructure = this.sessionStructure.toBuilder()
.clearPendingPreKey()
.build();
}
public void setRemoteRegistrationId(int registrationId) {
this.sessionStructure = this.sessionStructure.toBuilder()
.setRemoteRegistrationId(registrationId)
.build();
}
public int getRemoteRegistrationId() {
return this.sessionStructure.getRemoteRegistrationId();
}
public void setLocalRegistrationId(int registrationId) {
this.sessionStructure = this.sessionStructure.toBuilder()
.setLocalRegistrationId(registrationId)
.build();
}
public int getLocalRegistrationId() {
return this.sessionStructure.getLocalRegistrationId();
}
public byte[] serialize() {
return sessionStructure.toByteArray();
}
}

View File

@@ -5,8 +5,11 @@ import android.util.Log;
import org.whispersystems.libaxolotl.InvalidMessageException;
import org.whispersystems.libaxolotl.state.SessionRecord;
import org.whispersystems.libaxolotl.state.SessionState;
import org.whispersystems.libaxolotl.state.SessionStore;
import org.whispersystems.textsecure.crypto.MasterCipher;
import org.whispersystems.textsecure.crypto.MasterSecret;
import org.whispersystems.textsecure.util.Conversions;
import java.io.File;
import java.io.FileInputStream;
@@ -17,12 +20,18 @@ import java.nio.channels.FileChannel;
import java.util.LinkedList;
import java.util.List;
import static org.whispersystems.libaxolotl.state.StorageProtos.SessionStructure;
public class TextSecureSessionStore implements SessionStore {
private static final String TAG = TextSecureSessionStore.class.getSimpleName();
private static final String SESSIONS_DIRECTORY_V2 = "sessions-v2";
private static final Object FILE_LOCK = new Object();
private static final int SINGLE_STATE_VERSION = 1;
private static final int ARCHIVE_STATES_VERSION = 2;
private static final int CURRENT_VERSION = 2;
private final Context context;
private final MasterSecret masterSecret;
@@ -35,11 +44,30 @@ public class TextSecureSessionStore implements SessionStore {
public SessionRecord load(long recipientId, int deviceId) {
synchronized (FILE_LOCK) {
try {
FileInputStream input = new FileInputStream(getSessionFile(recipientId, deviceId));
return new TextSecureSessionRecord(masterSecret, input);
MasterCipher cipher = new MasterCipher(masterSecret);
FileInputStream in = new FileInputStream(getSessionFile(recipientId, deviceId));
int versionMarker = readInteger(in);
if (versionMarker > CURRENT_VERSION) {
throw new AssertionError("Unknown version: " + versionMarker);
}
byte[] serialized = cipher.decryptBytes(readBlob(in));
in.close();
if (versionMarker == SINGLE_STATE_VERSION) {
SessionStructure sessionStructure = SessionStructure.parseFrom(serialized);
SessionState sessionState = new SessionState(sessionStructure);
return new SessionRecord(sessionState);
} else if (versionMarker == ARCHIVE_STATES_VERSION) {
return new SessionRecord(serialized);
} else {
throw new AssertionError("Unknown version: " + versionMarker);
}
} catch (InvalidMessageException | IOException e) {
Log.w(TAG, "No existing session information found.");
return new TextSecureSessionRecord(masterSecret);
return new SessionRecord();
}
}
}
@@ -47,11 +75,13 @@ public class TextSecureSessionStore implements SessionStore {
@Override
public void store(long recipientId, int deviceId, SessionRecord record) {
try {
MasterCipher masterCipher = new MasterCipher(masterSecret);
RandomAccessFile sessionFile = new RandomAccessFile(getSessionFile(recipientId, deviceId), "rw");
FileChannel out = sessionFile.getChannel();
out.position(0);
out.write(ByteBuffer.wrap(record.serialize()));
writeInteger(CURRENT_VERSION, out);
writeBlob(masterCipher.encryptBytes(record.serialize()), out);
out.truncate(out.position());
sessionFile.close();
@@ -126,4 +156,28 @@ public class TextSecureSessionStore implements SessionStore {
return recipientId + (deviceId == RecipientDevice.DEFAULT_DEVICE_ID ? "" : "." + deviceId);
}
private byte[] readBlob(FileInputStream in) throws IOException {
int length = readInteger(in);
byte[] blobBytes = new byte[length];
in.read(blobBytes, 0, blobBytes.length);
return blobBytes;
}
private void writeBlob(byte[] blobBytes, FileChannel out) throws IOException {
writeInteger(blobBytes.length, out);
out.write(ByteBuffer.wrap(blobBytes));
}
private int readInteger(FileInputStream in) throws IOException {
byte[] integer = new byte[4];
in.read(integer, 0, integer.length);
return Conversions.byteArrayToInt(integer);
}
private void writeInteger(int value, FileChannel out) throws IOException {
byte[] valueBytes = Conversions.intToByteArray(value);
out.write(ByteBuffer.wrap(valueBytes));
}
}

View File

@@ -93,15 +93,6 @@ public class Util {
return value == null || value.length() == 0;
}
public static int generateRegistrationId() {
try {
return SecureRandom.getInstance("SHA1PRNG").nextInt(16380) + 1;
} catch (NoSuchAlgorithmException e) {
throw new AssertionError(e);
}
}
public static String getSecret(int size) {
try {
byte[] secret = new byte[size];