Introduce encrypted device creation timestamps

This commit is contained in:
Katherine
2025-07-23 10:36:11 -04:00
committed by GitHub
parent 74c7e49cea
commit 96f6e75702
13 changed files with 181 additions and 14 deletions

View File

@@ -190,12 +190,16 @@ class DeviceControllerTest {
final byte[] deviceName = "refreshed-device-name".getBytes(StandardCharsets.UTF_8);
final long deviceCreated = System.currentTimeMillis();
final long deviceLastSeen = deviceCreated + 1;
final int registrationId = 2;
final byte[] createdAtCiphertext = "timestamp ciphertext".getBytes(StandardCharsets.UTF_8);
final Device refreshedDevice = mock(Device.class);
when(refreshedDevice.getId()).thenReturn(deviceId);
when(refreshedDevice.getName()).thenReturn(deviceName);
when(refreshedDevice.getCreated()).thenReturn(deviceCreated);
when(refreshedDevice.getLastSeen()).thenReturn(deviceLastSeen);
when(refreshedDevice.getRegistrationId(IdentityType.ACI)).thenReturn(registrationId);
when(refreshedDevice.getCreatedAtCiphertext()).thenReturn(createdAtCiphertext);
final Account refreshedAccount = mock(Account.class);
when(refreshedAccount.getDevices()).thenReturn(List.of(refreshedDevice));
@@ -213,6 +217,8 @@ class DeviceControllerTest {
assertArrayEquals(deviceName, deviceInfoList.devices().getFirst().name());
assertEquals(deviceCreated, deviceInfoList.devices().getFirst().created());
assertEquals(deviceLastSeen, deviceInfoList.devices().getFirst().lastSeen());
assertEquals(registrationId, deviceInfoList.devices().getFirst().registrationId());
assertArrayEquals(createdAtCiphertext, deviceInfoList.devices().getFirst().createdAtCiphertext());
}
@ParameterizedTest
@@ -241,7 +247,8 @@ class DeviceControllerTest {
aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
final IdentityKey aciIdentityKey = new IdentityKey(aciIdentityKeyPair.getPublicKey());
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(aciIdentityKey);
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
when(accountsManager.checkDeviceLinkingToken(anyString())).thenReturn(Optional.of(AuthHelper.VALID_UUID));
@@ -250,7 +257,7 @@ class DeviceControllerTest {
final Account a = invocation.getArgument(0);
final DeviceSpec deviceSpec = invocation.getArgument(1);
return CompletableFuture.completedFuture(new Pair<>(a, deviceSpec.toDevice(NEXT_DEVICE_ID, testClock)));
return CompletableFuture.completedFuture(new Pair<>(a, deviceSpec.toDevice(NEXT_DEVICE_ID, testClock, aciIdentityKey)));
});
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null));
@@ -273,7 +280,7 @@ class DeviceControllerTest {
final ArgumentCaptor<DeviceSpec> deviceSpecCaptor = ArgumentCaptor.forClass(DeviceSpec.class);
verify(accountsManager).addDevice(eq(account), deviceSpecCaptor.capture(), any());
final Device device = deviceSpecCaptor.getValue().toDevice(NEXT_DEVICE_ID, testClock);
final Device device = deviceSpecCaptor.getValue().toDevice(NEXT_DEVICE_ID, testClock, aciIdentityKey);
assertEquals(fetchesMessages, device.getFetchesMessages());
@@ -741,15 +748,16 @@ class DeviceControllerTest {
final ECSignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair);
final KEMSignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
final KEMSignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
final IdentityKey aciIdentityKey = new IdentityKey(aciIdentityKeyPair.getPublicKey());
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(aciIdentityKey);
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
when(accountsManager.addDevice(any(), any(), any())).thenAnswer(invocation -> {
final Account a = invocation.getArgument(0);
final DeviceSpec deviceSpec = invocation.getArgument(1);
return CompletableFuture.completedFuture(new Pair<>(a, deviceSpec.toDevice(NEXT_DEVICE_ID, testClock)));
return CompletableFuture.completedFuture(new Pair<>(a, deviceSpec.toDevice(NEXT_DEVICE_ID, testClock, aciIdentityKey)));
});
when(accountsManager.checkDeviceLinkingToken(anyString())).thenReturn(Optional.of(AuthHelper.VALID_UUID));
@@ -953,7 +961,9 @@ class DeviceControllerTest {
final DeviceInfo deviceInfo = new DeviceInfo(Device.PRIMARY_ID,
"Device name ciphertext".getBytes(StandardCharsets.UTF_8),
System.currentTimeMillis(),
System.currentTimeMillis());
System.currentTimeMillis(),
1,
"timestamp ciphertext".getBytes(StandardCharsets.UTF_8));
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);
@@ -976,6 +986,8 @@ class DeviceControllerTest {
assertArrayEquals(deviceInfo.name(), retrievedDeviceInfo.name());
assertEquals(deviceInfo.created(), retrievedDeviceInfo.created());
assertEquals(deviceInfo.lastSeen(), retrievedDeviceInfo.lastSeen());
assertEquals(deviceInfo.registrationId(), retrievedDeviceInfo.registrationId());
assertArrayEquals(deviceInfo.createdAtCiphertext(), retrievedDeviceInfo.createdAtCiphertext());
}
}

View File

@@ -48,6 +48,7 @@ import org.signal.chat.device.SetDeviceNameRequest;
import org.signal.chat.device.SetDeviceNameResponse;
import org.signal.chat.device.SetPushTokenRequest;
import org.signal.chat.device.SetPushTokenResponse;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
@@ -102,11 +103,17 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
final Instant primaryDeviceLastSeen = primaryDeviceCreated.plus(Duration.ofHours(6));
final Instant linkedDeviceCreated = Instant.now().minus(Duration.ofDays(1)).truncatedTo(ChronoUnit.MILLIS);
final Instant linkedDeviceLastSeen = linkedDeviceCreated.plus(Duration.ofHours(7));
final int primaryRegistrationId = 1234;
final int linkedRegistrationId = 1235;
final byte[] primaryCreatedAtCiphertext = "primary_timestamp_ciphertext".getBytes(StandardCharsets.UTF_8);
final byte[] linkedCreatedAtCiphertext = "linked_timestamp_ciphertext".getBytes(StandardCharsets.UTF_8);
final Device primaryDevice = mock(Device.class);
when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(primaryDevice.getCreated()).thenReturn(primaryDeviceCreated.toEpochMilli());
when(primaryDevice.getLastSeen()).thenReturn(primaryDeviceLastSeen.toEpochMilli());
when(primaryDevice.getRegistrationId(IdentityType.ACI)).thenReturn(primaryRegistrationId);
when(primaryDevice.getCreatedAtCiphertext()).thenReturn(primaryCreatedAtCiphertext);
final String linkedDeviceName = "A linked device";
@@ -115,6 +122,8 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
when(linkedDevice.getCreated()).thenReturn(linkedDeviceCreated.toEpochMilli());
when(linkedDevice.getLastSeen()).thenReturn(linkedDeviceLastSeen.toEpochMilli());
when(linkedDevice.getName()).thenReturn(linkedDeviceName.getBytes(StandardCharsets.UTF_8));
when(linkedDevice.getRegistrationId(IdentityType.ACI)).thenReturn(linkedRegistrationId);
when(linkedDevice.getCreatedAtCiphertext()).thenReturn(linkedCreatedAtCiphertext);
when(authenticatedAccount.getDevices()).thenReturn(List.of(primaryDevice, linkedDevice));
@@ -123,12 +132,16 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
.setId(Device.PRIMARY_ID)
.setCreated(primaryDeviceCreated.toEpochMilli())
.setLastSeen(primaryDeviceLastSeen.toEpochMilli())
.setRegistrationId(primaryRegistrationId)
.setCreatedAtCiphertext(ByteString.copyFrom(primaryCreatedAtCiphertext))
.build())
.addDevices(GetDevicesResponse.LinkedDevice.newBuilder()
.setId(Device.PRIMARY_ID + 1)
.setCreated(linkedDeviceCreated.toEpochMilli())
.setLastSeen(linkedDeviceLastSeen.toEpochMilli())
.setName(ByteString.copyFrom(linkedDeviceName.getBytes(StandardCharsets.UTF_8)))
.setRegistrationId(linkedRegistrationId)
.setCreatedAtCiphertext(ByteString.copyFrom(linkedCreatedAtCiphertext))
.build())
.build();

View File

@@ -3,6 +3,7 @@ package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
@@ -536,7 +537,7 @@ public class AccountCreationDeletionIntegrationTest {
assertTrue(account.getRegistrationLock().verify(registrationLockSecret));
assertTrue(primaryDevice.getAuthTokenHash().verify(password));
assertNotNull(primaryDevice.getCreatedAtCiphertext());
assertEquals(Optional.of(aciSignedPreKey), keysManager.getEcSignedPreKey(account.getIdentifier(IdentityType.ACI), Device.PRIMARY_ID).join());
assertEquals(Optional.of(pniSignedPreKey), keysManager.getEcSignedPreKey(account.getIdentifier(IdentityType.PNI), Device.PRIMARY_ID).join());
assertEquals(Optional.of(aciPqLastResortPreKey), keysManager.getLastResort(account.getIdentifier(IdentityType.ACI), Device.PRIMARY_ID).join());

View File

@@ -928,6 +928,7 @@ class AccountsManagerTest {
final Account account = AccountsHelper.generateTestAccount(phoneNumber, List.of(generateTestDevice(CLOCK.millis())));
final UUID aci = account.getIdentifier(IdentityType.ACI);
final UUID pni = account.getIdentifier(IdentityType.PNI);
account.setIdentityKey(new IdentityKey(ECKeyPair.generate().getPublicKey()));
final byte nextDeviceId = account.getNextDeviceId();

View File

@@ -3,6 +3,7 @@ package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
@@ -474,6 +475,8 @@ public class AddRemoveDeviceIntegrationTest {
assertEquals(updatedAccountAndDevice.second().getId(), deviceInfo.id());
assertEquals(updatedAccountAndDevice.second().getCreated(), deviceInfo.created());
assertEquals(updatedAccountAndDevice.second().getRegistrationId(IdentityType.ACI), deviceInfo.registrationId());
assertNotNull(deviceInfo.createdAtCiphertext());
}
@Test
@@ -521,6 +524,8 @@ public class AddRemoveDeviceIntegrationTest {
assertEquals(updatedAccountAndDevice.second().getId(), deviceInfo.id());
assertEquals(updatedAccountAndDevice.second().getCreated(), deviceInfo.created());
assertEquals(updatedAccountAndDevice.second().getRegistrationId(IdentityType.ACI), deviceInfo.registrationId());
assertNotNull(deviceInfo.createdAtCiphertext());
}
@Test

View File

@@ -0,0 +1,36 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import org.junit.jupiter.api.Test;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import java.nio.ByteBuffer;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.whispersystems.textsecuregcm.util.EncryptDeviceCreationTimestampUtil.ENCRYPTION_INFO;
public class EncryptDeviceCreationTimestampUtilTest {
@Test
void encryptDecrypt() throws InvalidMessageException {
final long createdAt = System.currentTimeMillis();
final ECKeyPair keyPair = ECKeyPair.generate();
final byte deviceId = 1;
final int registrationId = 123;
final byte[] ciphertext = EncryptDeviceCreationTimestampUtil.encrypt(createdAt, new IdentityKey(keyPair.getPublicKey()),
deviceId, registrationId);
final ByteBuffer associatedData = ByteBuffer.allocate(5);
associatedData.put(deviceId);
associatedData.putInt(registrationId);
final byte[] decryptedData = keyPair.getPrivateKey().open(ciphertext, ENCRYPTION_INFO, associatedData.array());
assertEquals(createdAt, ByteBuffer.wrap(decryptedData).getLong());
}
}