Convert prekey requests to WebSocket.

This commit is contained in:
Cody Henthorne
2025-03-14 18:17:23 -04:00
parent da3fc408f8
commit aeec3a6f7e
15 changed files with 276 additions and 253 deletions

View File

@@ -26,6 +26,7 @@ import org.thoughtcrime.securesms.testing.runSync
import org.thoughtcrime.securesms.testing.success
import org.whispersystems.signalservice.api.SignalServiceDataStore
import org.whispersystems.signalservice.api.SignalServiceMessageSender
import org.whispersystems.signalservice.api.keys.KeysApi
import org.whispersystems.signalservice.api.message.MessageApi
import org.whispersystems.signalservice.api.push.TrustStore
import org.whispersystems.signalservice.api.websocket.SignalWebSocket
@@ -125,10 +126,11 @@ class InstrumentationApplicationDependencyProvider(val application: Application,
authWebSocket: SignalWebSocket.AuthenticatedWebSocket,
protocolStore: SignalServiceDataStore,
pushServiceSocket: PushServiceSocket,
messageApi: MessageApi
messageApi: MessageApi,
keysApi: KeysApi
): SignalServiceMessageSender {
if (signalServiceMessageSender == null) {
signalServiceMessageSender = spyk(objToCopy = default.provideSignalServiceMessageSender(authWebSocket, protocolStore, pushServiceSocket, messageApi))
signalServiceMessageSender = spyk(objToCopy = default.provideSignalServiceMessageSender(authWebSocket, protocolStore, pushServiceSocket, messageApi, keysApi))
}
return signalServiceMessageSender!!
}

View File

@@ -161,7 +161,7 @@ class ChangeNumberRepository(
pniMetadataStore.activeSignedPreKeyId = signedPreKey.id
Log.i(TAG, "Submitting prekeys with PNI identity key: ${pniIdentityKeyPair.publicKey.fingerprint}")
accountManager.setPreKeys(
SignalNetwork.keys.setPreKeys(
PreKeyUpload(
serviceIdType = ServiceIdType.PNI,
signedPreKey = signedPreKey,
@@ -169,7 +169,7 @@ class ChangeNumberRepository(
lastResortKyberPreKey = lastResortKyberPreKey,
oneTimeKyberPreKeys = oneTimeKyberPreKeys
)
)
).successOrThrow()
pniMetadataStore.isSignedPreKeyRegistered = true
pniMetadataStore.lastResortKyberPreKeyId = pniLastResortKyberPreKeyId

View File

@@ -362,7 +362,7 @@ object AppDependencies {
fun providePushServiceSocket(signalServiceConfiguration: SignalServiceConfiguration, groupsV2Operations: GroupsV2Operations): PushServiceSocket
fun provideGroupsV2Operations(signalServiceConfiguration: SignalServiceConfiguration): GroupsV2Operations
fun provideSignalServiceAccountManager(authWebSocket: SignalWebSocket.AuthenticatedWebSocket, accountApi: AccountApi, pushServiceSocket: PushServiceSocket, groupsV2Operations: GroupsV2Operations): SignalServiceAccountManager
fun provideSignalServiceMessageSender(authWebSocket: SignalWebSocket.AuthenticatedWebSocket, protocolStore: SignalServiceDataStore, pushServiceSocket: PushServiceSocket, messageApi: MessageApi): SignalServiceMessageSender
fun provideSignalServiceMessageSender(authWebSocket: SignalWebSocket.AuthenticatedWebSocket, protocolStore: SignalServiceDataStore, pushServiceSocket: PushServiceSocket, messageApi: MessageApi, keysApi: KeysApi): SignalServiceMessageSender
fun provideSignalServiceMessageReceiver(pushServiceSocket: PushServiceSocket): SignalServiceMessageReceiver
fun provideSignalServiceNetworkAccess(): SignalServiceNetworkAccess
fun provideRecipientCache(): LiveRecipientCache
@@ -397,7 +397,7 @@ object AppDependencies {
fun provideLibsignalNetwork(config: SignalServiceConfiguration): Network
fun provideBillingApi(): BillingApi
fun provideArchiveApi(authWebSocket: SignalWebSocket.AuthenticatedWebSocket, unauthWebSocket: SignalWebSocket.UnauthenticatedWebSocket, pushServiceSocket: PushServiceSocket): ArchiveApi
fun provideKeysApi(pushServiceSocket: PushServiceSocket): KeysApi
fun provideKeysApi(authWebSocket: SignalWebSocket.AuthenticatedWebSocket, unauthWebSocket: SignalWebSocket.UnauthenticatedWebSocket): KeysApi
fun provideAttachmentApi(authWebSocket: SignalWebSocket.AuthenticatedWebSocket, pushServiceSocket: PushServiceSocket): AttachmentApi
fun provideLinkDeviceApi(authWebSocket: SignalWebSocket.AuthenticatedWebSocket): LinkDeviceApi
fun provideRegistrationApi(pushServiceSocket: PushServiceSocket): RegistrationApi

View File

@@ -151,12 +151,17 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider {
}
@Override
public @NonNull SignalServiceMessageSender provideSignalServiceMessageSender(@NonNull SignalWebSocket.AuthenticatedWebSocket authWebSocket, @NonNull SignalServiceDataStore protocolStore, @NonNull PushServiceSocket pushServiceSocket, @NonNull MessageApi messageApi) {
public @NonNull SignalServiceMessageSender provideSignalServiceMessageSender(@NonNull SignalWebSocket.AuthenticatedWebSocket authWebSocket,
@NonNull SignalServiceDataStore protocolStore,
@NonNull PushServiceSocket pushServiceSocket,
@NonNull MessageApi messageApi,
@NonNull KeysApi keysApi) {
return new SignalServiceMessageSender(pushServiceSocket,
protocolStore,
ReentrantSessionLock.INSTANCE,
authWebSocket,
messageApi,
keysApi,
Optional.of(new SecurityEventListener(context)),
SignalExecutors.newCachedBoundedExecutor("signal-messages", ThreadUtil.PRIORITY_IMPORTANT_BACKGROUND_THREAD, 1, 16, 30),
ByteUnit.KILOBYTES.toBytes(256));
@@ -471,8 +476,8 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider {
}
@Override
public @NonNull KeysApi provideKeysApi(@NonNull PushServiceSocket pushServiceSocket) {
return new KeysApi(pushServiceSocket);
public @NonNull KeysApi provideKeysApi(@NonNull SignalWebSocket.AuthenticatedWebSocket authWebSocket, @NonNull SignalWebSocket.UnauthenticatedWebSocket unauthWebSocket) {
return new KeysApi(authWebSocket, unauthWebSocket);
}
@Override

View File

@@ -82,7 +82,7 @@ class NetworkDependenciesModule(
val protocolStore: SignalServiceDataStoreImpl by _protocolStore
private val _signalServiceMessageSender = resettableLazy {
provider.provideSignalServiceMessageSender(authWebSocket, protocolStore, pushServiceSocket, messageApi)
provider.provideSignalServiceMessageSender(authWebSocket, protocolStore, pushServiceSocket, messageApi, keysApi)
}
val signalServiceMessageSender: SignalServiceMessageSender by _signalServiceMessageSender
@@ -146,7 +146,7 @@ class NetworkDependenciesModule(
}
val keysApi: KeysApi by lazy {
provider.provideKeysApi(pushServiceSocket)
provider.provideKeysApi(authWebSocket, unauthWebSocket)
}
val attachmentApi: AttachmentApi by lazy {

View File

@@ -18,18 +18,16 @@ import org.thoughtcrime.securesms.jobs.protos.PreKeysSyncJobData
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.net.SignalNetwork
import org.thoughtcrime.securesms.util.RemoteConfig
import org.thoughtcrime.securesms.util.isRetryableIOException
import org.whispersystems.signalservice.api.NetworkResult
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore
import org.whispersystems.signalservice.api.account.PreKeyUpload
import org.whispersystems.signalservice.api.push.ServiceId
import org.whispersystems.signalservice.api.push.ServiceIdType
import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException
import org.whispersystems.signalservice.api.push.exceptions.PushNetworkException
import org.whispersystems.signalservice.internal.push.OneTimePreKeyCounts
import java.io.IOException
import java.net.ProtocolException
import java.util.concurrent.TimeUnit
import kotlin.jvm.Throws
import kotlin.time.Duration.Companion.days
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.DurationUnit
@@ -169,8 +167,7 @@ class PreKeysSyncJob private constructor(
return
}
val accountManager = AppDependencies.signalServiceAccountManager
val availablePreKeyCounts: OneTimePreKeyCounts = accountManager.getPreKeyCounts(serviceIdType)
val availablePreKeyCounts = SignalNetwork.keys.getAvailablePreKeyCounts(serviceIdType).successOrThrow()
val signedPreKeyToUpload: SignedPreKeyRecord? = signedPreKeyUploadIfNeeded(serviceIdType, protocolStore, metadataStore, forceRotation)
@@ -194,7 +191,7 @@ class PreKeysSyncJob private constructor(
if (signedPreKeyToUpload != null || oneTimeEcPreKeysToUpload != null || lastResortKyberPreKeyToUpload != null || oneTimeKyberPreKeysToUpload != null) {
log(serviceIdType, "Something to upload. SignedPreKey: ${signedPreKeyToUpload != null}, OneTimeEcPreKeys: ${oneTimeEcPreKeysToUpload != null}, LastResortKyberPreKey: ${lastResortKyberPreKeyToUpload != null}, OneTimeKyberPreKeys: ${oneTimeKyberPreKeysToUpload != null}")
accountManager.setPreKeys(
SignalNetwork.keys.setPreKeys(
PreKeyUpload(
serviceIdType = serviceIdType,
signedPreKey = signedPreKeyToUpload,
@@ -202,7 +199,7 @@ class PreKeysSyncJob private constructor(
lastResortKyberPreKey = lastResortKyberPreKeyToUpload,
oneTimeKyberPreKeys = oneTimeKyberPreKeysToUpload
)
)
).successOrThrow()
if (signedPreKeyToUpload != null) {
log(serviceIdType, "Successfully uploaded signed prekey.")
@@ -292,11 +289,7 @@ class PreKeysSyncJob private constructor(
}
override fun onShouldRetry(e: Exception): Boolean {
return when (e) {
is NonSuccessfulResponseCodeException -> false
is PushNetworkException -> true
else -> false
}
return e.isRetryableIOException()
}
override fun onFailure() {

View File

@@ -12,9 +12,10 @@ import org.thoughtcrime.securesms.dependencies.AppDependencies;
import org.thoughtcrime.securesms.jobmanager.Job;
import org.thoughtcrime.securesms.jobmanager.impl.NetworkConstraint;
import org.thoughtcrime.securesms.keyvalue.SignalStore;
import org.thoughtcrime.securesms.net.SignalNetwork;
import org.thoughtcrime.securesms.recipients.Recipient;
import org.whispersystems.signalservice.api.NetworkResultUtil;
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore;
import org.whispersystems.signalservice.api.SignalServiceAccountManager;
import org.whispersystems.signalservice.api.account.PreKeyUpload;
import org.whispersystems.signalservice.api.push.ServiceId.PNI;
import org.whispersystems.signalservice.api.push.ServiceIdType;
@@ -70,7 +71,6 @@ public class PniAccountInitializationMigrationJob extends MigrationJob {
Log.w(TAG, "Already generated the PNI identity. Skipping this step.");
}
SignalServiceAccountManager accountManager = AppDependencies.getSignalServiceAccountManager();
SignalServiceAccountDataStore protocolStore = AppDependencies.getProtocolStore().pni();
PreKeyMetadataStore metadataStore = SignalStore.account().pniPreKeys();
@@ -79,7 +79,7 @@ public class PniAccountInitializationMigrationJob extends MigrationJob {
SignedPreKeyRecord signedPreKey = PreKeyUtil.generateAndStoreSignedPreKey(protocolStore, metadataStore);
List<PreKeyRecord> oneTimePreKeys = PreKeyUtil.generateAndStoreOneTimeEcPreKeys(protocolStore, metadataStore);
accountManager.setPreKeys(new PreKeyUpload(ServiceIdType.PNI, signedPreKey, oneTimePreKeys, null, null));
NetworkResultUtil.toPreKeysLegacy(SignalNetwork.keys().setPreKeys(new PreKeyUpload(ServiceIdType.PNI, signedPreKey, oneTimePreKeys, null, null)));
metadataStore.setActiveSignedPreKeyId(signedPreKey.getId());
metadataStore.setSignedPreKeyRegistered(true);
} else {

View File

@@ -42,6 +42,8 @@ object SignalNetwork {
val cdsApi: CdsApi
get() = AppDependencies.cdsApi
@JvmStatic
@get:JvmName("keys")
val keys: KeysApi
get() = AppDependencies.keysApi

View File

@@ -74,7 +74,8 @@ class MockApplicationDependencyProvider : AppDependencies.Provider {
authWebSocket: SignalWebSocket.AuthenticatedWebSocket,
protocolStore: SignalServiceDataStore,
pushServiceSocket: PushServiceSocket,
messageApi: MessageApi
messageApi: MessageApi,
keysApi: KeysApi
): SignalServiceMessageSender {
return mockk(relaxed = true)
}
@@ -220,7 +221,7 @@ class MockApplicationDependencyProvider : AppDependencies.Provider {
return mockk(relaxed = true)
}
override fun provideKeysApi(pushServiceSocket: PushServiceSocket): KeysApi {
override fun provideKeysApi(authWebSocket: SignalWebSocket.AuthenticatedWebSocket, unauthWebSocket: SignalWebSocket.UnauthenticatedWebSocket): KeysApi {
return mockk(relaxed = true)
}

View File

@@ -130,4 +130,28 @@ object NetworkResultUtil {
}
}
}
@JvmStatic
@Throws(IOException::class)
fun <T> toPreKeysLegacy(result: NetworkResult<T>): T {
return when (result) {
is NetworkResult.Success -> result.result
is NetworkResult.StatusCodeError -> {
throw when (result.code) {
400, 401 -> AuthorizationFailedException(result.code, "Authorization failed!")
404 -> NotFoundException("Not found")
429 -> RateLimitException(result.code, "Rate limit exceeded: ${result.code}", Optional.empty())
508 -> ServerRejectedException()
else -> result.exception
}
}
is NetworkResult.NetworkError -> throw result.exception
is NetworkResult.ApplicationError -> {
throw when (val error = result.throwable) {
is IOException, is RuntimeException -> error
else -> RuntimeException(error)
}
}
}
}
}

View File

@@ -10,7 +10,6 @@ import org.signal.libsignal.net.Network;
import org.signal.libsignal.zkgroup.profiles.ExpiringProfileKeyCredential;
import org.signal.libsignal.zkgroup.profiles.ProfileKey;
import org.whispersystems.signalservice.api.account.AccountApi;
import org.whispersystems.signalservice.api.account.PreKeyUpload;
import org.whispersystems.signalservice.api.crypto.ProfileCipher;
import org.whispersystems.signalservice.api.crypto.ProfileCipherOutputStream;
import org.whispersystems.signalservice.api.crypto.SealedSenderAccess;
@@ -22,7 +21,6 @@ import org.whispersystems.signalservice.api.profiles.ProfileAndCredential;
import org.whispersystems.signalservice.api.profiles.SignalServiceProfileWrite;
import org.whispersystems.signalservice.api.push.ServiceId.ACI;
import org.whispersystems.signalservice.api.push.ServiceId.PNI;
import org.whispersystems.signalservice.api.push.ServiceIdType;
import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException;
import org.whispersystems.signalservice.api.push.exceptions.PushNetworkException;
import org.whispersystems.signalservice.api.registration.RegistrationApi;
@@ -30,7 +28,6 @@ import org.whispersystems.signalservice.api.svr.SecureValueRecoveryV2;
import org.whispersystems.signalservice.api.svr.SecureValueRecoveryV3;
import org.whispersystems.signalservice.api.websocket.SignalWebSocket;
import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration;
import org.whispersystems.signalservice.internal.push.OneTimePreKeyCounts;
import org.whispersystems.signalservice.internal.push.PaymentAddress;
import org.whispersystems.signalservice.internal.push.ProfileAvatarData;
import org.whispersystems.signalservice.internal.push.PushServiceSocket;
@@ -145,26 +142,6 @@ public class SignalServiceAccountManager {
pushServiceSocket.requestPushChallenge(sessionId, gcmRegistrationId);
}
/**
* Register an identity key, signed prekey, and list of one time prekeys
* with the server.
*
* @throws IOException
*/
public void setPreKeys(PreKeyUpload preKeyUpload)
throws IOException
{
this.pushServiceSocket.registerPreKeys(preKeyUpload);
}
/**
* @return The server's count of currently available (eg. unused) prekeys for this user.
* @throws IOException
*/
public OneTimePreKeyCounts getPreKeyCounts(ServiceIdType serviceIdType) throws IOException {
return this.pushServiceSocket.getAvailablePreKeys(serviceIdType);
}
public RemoteConfigResult getRemoteConfig() throws IOException {
RemoteConfigResponse response = this.pushServiceSocket.getRemoteConfig();
Map<String, Object> out = new HashMap<>();

View File

@@ -31,6 +31,7 @@ import org.whispersystems.signalservice.api.crypto.SignalSessionBuilder;
import org.whispersystems.signalservice.api.crypto.UnidentifiedAccess;
import org.whispersystems.signalservice.api.crypto.UntrustedIdentityException;
import org.whispersystems.signalservice.api.groupsv2.GroupSendEndorsements;
import org.whispersystems.signalservice.api.keys.KeysApi;
import org.whispersystems.signalservice.api.message.MessageApi;
import org.whispersystems.signalservice.api.messages.SendMessageResult;
import org.whispersystems.signalservice.api.messages.SignalServiceAttachment;
@@ -85,7 +86,6 @@ import org.whispersystems.signalservice.api.util.Uint64Util;
import org.whispersystems.signalservice.api.util.UuidUtil;
import org.whispersystems.signalservice.api.websocket.SignalWebSocket;
import org.whispersystems.signalservice.api.websocket.WebSocketUnavailableException;
import org.whispersystems.signalservice.internal.ServiceResponse;
import org.whispersystems.signalservice.internal.crypto.AttachmentDigest;
import org.whispersystems.signalservice.internal.crypto.PaddingInputStream;
import org.whispersystems.signalservice.internal.push.AttachmentPointer;
@@ -177,6 +177,7 @@ public class SignalServiceMessageSender {
private final AttachmentService attachmentService;
private final MessageApi messageApi;
private final KeysApi keysApi;
private final Scheduler scheduler;
private final long maxEnvelopeSize;
@@ -186,6 +187,7 @@ public class SignalServiceMessageSender {
SignalSessionLock sessionLock,
SignalWebSocket.AuthenticatedWebSocket authWebSocket,
MessageApi messageApi,
KeysApi keysApi,
Optional<EventListener> eventListener,
ExecutorService executor,
long maxEnvelopeSize)
@@ -204,6 +206,7 @@ public class SignalServiceMessageSender {
this.maxEnvelopeSize = maxEnvelopeSize;
this.localPniIdentity = store.pni().getIdentityKeyPair();
this.scheduler = Schedulers.from(executor, false, false);
this.keysApi = keysApi;
}
/**
@@ -1983,7 +1986,7 @@ public class SignalServiceMessageSender {
}
} catch (MismatchedDevicesException mde) {
Log.w(TAG, "[sendMessage][" + timestamp + "] Handling mismatched devices. (" + mde.getMessage() + ")");
handleMismatchedDevices(socket, recipient, mde.getMismatchedDevices());
handleMismatchedDevices(recipient, mde.getMismatchedDevices());
} catch (StaleDevicesException ste) {
Log.w(TAG, "[sendMessage][" + timestamp + "] Handling stale devices. (" + ste.getMessage() + ")");
handleStaleDevices(recipient, ste.getStaleDevices());
@@ -2228,7 +2231,7 @@ public class SignalServiceMessageSender {
Log.w(TAG, "[sendMessage][" + timestamp + "] Handling mismatched devices. (" + mde.getMessage() + ")");
return Single.fromCallable(() -> {
handleMismatchedDevices(socket, recipient, mde.getMismatchedDevices());
handleMismatchedDevices(recipient, mde.getMismatchedDevices());
return Unit.INSTANCE;
})
.flatMap(unused -> sendMessageRx(
@@ -2451,7 +2454,7 @@ public class SignalServiceMessageSender {
Log.w(TAG, "[sendGroupMessage][" + timestamp + "] Handling mismatched devices. (" + e.getMessage() + ")");
for (GroupMismatchedDevices mismatched : e.getMismatchedDevices()) {
SignalServiceAddress address = new SignalServiceAddress(ServiceId.parseOrThrow(mismatched.getUuid()), Optional.empty());
handleMismatchedDevices(socket, address, mismatched.getDevices());
handleMismatchedDevices(address, mismatched.getDevices());
}
} catch (GroupStaleDevicesException e) {
Log.w(TAG, "[sendGroupMessage][" + timestamp + "] Handling stale devices. (" + e.getMessage() + ")");
@@ -2703,19 +2706,18 @@ public class SignalServiceMessageSender {
sealedSenderAccess = null;
}
return socket.getPreKeys(recipient, sealedSenderAccess, deviceId);
return NetworkResultUtil.toPreKeysLegacy(keysApi.getPreKeys(recipient, sealedSenderAccess, deviceId));
} catch (NonSuccessfulResponseCodeException e) {
if (e.code == 401 && story) {
Log.d(TAG, "Got 401 when fetching prekey for story. Trying without UD.");
return socket.getPreKeys(recipient, null, deviceId);
return NetworkResultUtil.toPreKeysLegacy(keysApi.getPreKeys(recipient, null, deviceId));
} else {
throw e;
}
}
}
private void handleMismatchedDevices(PushServiceSocket socket,
SignalServiceAddress recipient,
private void handleMismatchedDevices(SignalServiceAddress recipient,
MismatchedDevices mismatchedDevices)
throws IOException, UntrustedIdentityException
{
@@ -2724,7 +2726,7 @@ public class SignalServiceMessageSender {
archiveSessions(recipient, mismatchedDevices.getExtraDevices());
for (int missingDeviceId : mismatchedDevices.getMissingDevices()) {
PreKeyBundle preKey = socket.getPreKey(recipient, missingDeviceId);
PreKeyBundle preKey = NetworkResultUtil.toPreKeysLegacy(keysApi.getPreKey(recipient, missingDeviceId));
try {
SignalSessionBuilder sessionBuilder = new SignalSessionBuilder(sessionLock, new SessionBuilder(aciStore, new SignalProtocolAddress(recipient.getIdentifier(), missingDeviceId)));
@@ -2746,7 +2748,7 @@ public class SignalServiceMessageSender {
public void handleChangeNumberMismatchDevices(@Nonnull MismatchedDevices mismatchedDevices)
throws IOException, UntrustedIdentityException
{
handleMismatchedDevices(socket, localAddress, mismatchedDevices);
handleMismatchedDevices(localAddress, mismatchedDevices);
}
private void archiveSessions(SignalServiceAddress recipient, List<Integer> devices) {

View File

@@ -5,25 +5,44 @@
package org.whispersystems.signalservice.api.keys
import org.signal.core.util.logging.Log
import org.signal.core.util.toByteArray
import org.signal.libsignal.protocol.IdentityKey
import org.signal.libsignal.protocol.ecc.ECPublicKey
import org.signal.libsignal.protocol.kem.KEMPublicKey
import org.signal.libsignal.protocol.state.PreKeyBundle
import org.signal.libsignal.protocol.state.PreKeyRecord
import org.whispersystems.signalservice.api.NetworkResult
import org.whispersystems.signalservice.api.account.PreKeyUpload
import org.whispersystems.signalservice.api.crypto.SealedSenderAccess
import org.whispersystems.signalservice.api.push.ServiceIdType
import org.whispersystems.signalservice.internal.push.PushServiceSocket
import org.whispersystems.signalservice.api.push.SignalServiceAddress
import org.whispersystems.signalservice.api.push.SignedPreKeyEntity
import org.whispersystems.signalservice.api.push.exceptions.UnregisteredUserException
import org.whispersystems.signalservice.api.websocket.SignalWebSocket
import org.whispersystems.signalservice.internal.get
import org.whispersystems.signalservice.internal.post
import org.whispersystems.signalservice.internal.push.CheckRepeatedUsedPreKeysRequest
import org.whispersystems.signalservice.internal.push.KyberPreKeyEntity
import org.whispersystems.signalservice.internal.push.PreKeyEntity
import org.whispersystems.signalservice.internal.push.PreKeyResponse
import org.whispersystems.signalservice.internal.push.PreKeyState
import org.whispersystems.signalservice.internal.put
import org.whispersystems.signalservice.internal.websocket.WebSocketRequestMessage
import java.io.IOException
import java.security.MessageDigest
import java.util.LinkedList
/**
* Contains APIs for interacting with /keys endpoints on the service.
*/
class KeysApi(private val pushServiceSocket: PushServiceSocket) {
class KeysApi(
private val authWebSocket: SignalWebSocket.AuthenticatedWebSocket,
private val unauthWebSocket: SignalWebSocket.UnauthenticatedWebSocket
) {
companion object {
@JvmStatic
fun create(pushServiceSocket: PushServiceSocket): KeysApi {
return KeysApi(pushServiceSocket)
}
private val TAG = Log.tag(KeysApi::class)
}
/**
@@ -50,8 +69,186 @@ class KeysApi(private val pushServiceSocket: PushServiceSocket) {
update(lastResortKyberKey.serialize())
}
return NetworkResult.fromFetch {
pushServiceSocket.checkRepeatedUsePreKeys(serviceIdType, digest.digest())
val body = CheckRepeatedUsedPreKeysRequest(serviceIdType.toString(), digest.digest())
val request = WebSocketRequestMessage.post("/v2/keys/check", body)
return NetworkResult.fromWebSocketRequest(authWebSocket, request)
}
/**
* The server's count of currently available (eg. unused) prekeys for this user.
*
* GET /v2/keys?identity=[serviceIdType]
* - 200: Success
*/
fun getAvailablePreKeyCounts(serviceIdType: ServiceIdType): NetworkResult<OneTimePreKeyCounts> {
val request = WebSocketRequestMessage.get("/v2/keys?identity=${serviceIdType.queryParam()}")
return NetworkResult.fromWebSocketRequest(authWebSocket, request, OneTimePreKeyCounts::class)
}
/**
* Register an identity key, signed prekey, and list of one time prekeys with the server.
*
* PUT /v2/keys?identity=[preKeyUpload]`.serviceIdType`
*/
fun setPreKeys(preKeyUpload: PreKeyUpload): NetworkResult<Unit> {
val signedPreKey: SignedPreKeyEntity? = if (preKeyUpload.signedPreKey != null) {
SignedPreKeyEntity(
preKeyUpload.signedPreKey.id,
preKeyUpload.signedPreKey.keyPair.publicKey,
preKeyUpload.signedPreKey.signature
)
} else {
null
}
val oneTimeEcPreKeys: List<PreKeyEntity>? = if (preKeyUpload.oneTimeEcPreKeys != null) {
preKeyUpload
.oneTimeEcPreKeys
.map { oneTimeEcKey: PreKeyRecord -> PreKeyEntity(oneTimeEcKey.id, oneTimeEcKey.keyPair.publicKey) }
} else {
null
}
val lastResortKyberPreKey: KyberPreKeyEntity? = if (preKeyUpload.lastResortKyberPreKey != null) {
KyberPreKeyEntity(
preKeyUpload.lastResortKyberPreKey.id,
preKeyUpload.lastResortKyberPreKey.keyPair.publicKey,
preKeyUpload.lastResortKyberPreKey.signature
)
} else {
null
}
val oneTimeKyberPreKeys: List<KyberPreKeyEntity>? = if (preKeyUpload.oneTimeKyberPreKeys != null) {
preKeyUpload
.oneTimeKyberPreKeys
.map { record -> KyberPreKeyEntity(record.id, record.keyPair.publicKey, record.signature) }
} else {
null
}
val body = PreKeyState(signedPreKey, oneTimeEcPreKeys, lastResortKyberPreKey, oneTimeKyberPreKeys)
val request = WebSocketRequestMessage.put("/v2/keys?identity=${preKeyUpload.serviceIdType.queryParam()}", body)
return NetworkResult.fromWebSocketRequest(authWebSocket, request)
}
/**
* Retrieves prekeys. If the specified device is the primary (i.e. deviceId 1), it will retrieve prekeys
* for all devices. If it is not a primary, it will only contain the prekeys for that specific device.
*
* GET /v2/keys/[destination]`.identifier`/[deviceSpecifier]
* - 200: Success
* - 400: Multiple forms of authentication provided
* - 401: No valid authentication provided
* - 404: No keys found for address/device
* - 429: Rate limited
*/
fun getPreKeys(
destination: SignalServiceAddress,
sealedSenderAccess: SealedSenderAccess?,
deviceId: Int
): NetworkResult<List<PreKeyBundle>> {
return getPreKeysBySpecifier(destination, sealedSenderAccess, if (deviceId == 1) "*" else deviceId.toString())
}
/**
* Retrieves a prekey for a specific device.
*
* GET /v2/keys/[destination]`.identifier`/[deviceSpecifier]
* - 200: Success
* - 400: Multiple forms of authentication provided
* - 401: No valid authentication provided
* - 404: No keys found for address/device
* - 429: Rate limited
*/
fun getPreKey(destination: SignalServiceAddress, deviceId: Int): NetworkResult<PreKeyBundle> {
return getPreKeysBySpecifier(destination, null, deviceId.toString())
.then { bundles ->
if (bundles.isNotEmpty()) {
NetworkResult.Success(bundles[0])
} else {
NetworkResult.NetworkError(IOException("No prekeys available!"))
}
}
}
/**
* Retrieves the public identity key and available device prekeys for the specified [destination]. Results can
* be restricted to a specific device setting the device number for [deviceSpecifier] or can get all devices by passing `*`.
*
* GET /v2/keys/[destination]`.identifier`/[deviceSpecifier]
* - 200: Success
* - 400: Multiple forms of authentication provided
* - 401: No valid authentication provided
* - 404: No keys found for address/device
* - 429: Rate limited
*/
private fun getPreKeysBySpecifier(destination: SignalServiceAddress, sealedSenderAccess: SealedSenderAccess?, deviceSpecifier: String): NetworkResult<List<PreKeyBundle>> {
val request = WebSocketRequestMessage.get("/v2/keys/${destination.identifier}/$deviceSpecifier")
Log.d(TAG, "Fetching prekeys for ${destination.identifier}.$deviceSpecifier, i.e. GET ${request.path}")
val result: NetworkResult<PreKeyResponse> = NetworkResult.fromWebSocket {
if (sealedSenderAccess != null) {
unauthWebSocket.request(request, sealedSenderAccess)
} else {
authWebSocket.request(request)
}
}
if (result is NetworkResult.StatusCodeError && result.code == 404) {
return NetworkResult.NetworkError(UnregisteredUserException(destination.identifier, result.exception))
}
return result.map { response ->
val bundles: MutableList<PreKeyBundle> = LinkedList()
for (device in response.getDevices()) {
var preKey: ECPublicKey? = null
var signedPreKey: ECPublicKey? = null
var signedPreKeySignature: ByteArray? = null
var preKeyId = PreKeyBundle.NULL_PRE_KEY_ID
var signedPreKeyId = PreKeyBundle.NULL_PRE_KEY_ID
var kyberPreKeyId = PreKeyBundle.NULL_PRE_KEY_ID
var kyberPreKey: KEMPublicKey? = null
var kyberPreKeySignature: ByteArray? = null
if (device.getSignedPreKey() != null) {
signedPreKey = device.getSignedPreKey().publicKey
signedPreKeyId = device.getSignedPreKey().keyId
signedPreKeySignature = device.getSignedPreKey().signature
}
if (device.getPreKey() != null) {
preKeyId = device.getPreKey().keyId
preKey = device.getPreKey().publicKey
}
if (device.getKyberPreKey() != null) {
kyberPreKey = device.getKyberPreKey().publicKey
kyberPreKeyId = device.getKyberPreKey().keyId
kyberPreKeySignature = device.getKyberPreKey().signature
}
bundles.add(
PreKeyBundle(
device.getRegistrationId(),
device.getDeviceId(),
preKeyId,
preKey,
signedPreKeyId,
signedPreKey,
signedPreKeySignature,
response.getIdentityKey(),
kyberPreKeyId,
kyberPreKey,
kyberPreKeySignature
)
)
}
bundles
}
}
}

View File

@@ -1,10 +1,9 @@
/**
* Copyright (C) 2014-2016 Open Whisper Systems
*
* Licensed according to the LICENSE file in this repository.
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.signalservice.internal.push;
package org.whispersystems.signalservice.api.keys;
import com.fasterxml.jackson.annotation.JsonProperty;

View File

@@ -15,10 +15,7 @@ import org.signal.core.util.concurrent.FutureTransformers;
import org.signal.core.util.concurrent.ListenableFuture;
import org.signal.core.util.concurrent.SettableFuture;
import org.signal.libsignal.protocol.InvalidKeyException;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.signal.libsignal.protocol.kem.KEMPublicKey;
import org.signal.libsignal.protocol.logging.Log;
import org.signal.libsignal.protocol.state.PreKeyBundle;
import org.signal.libsignal.protocol.util.Pair;
import org.signal.libsignal.zkgroup.VerificationFailedException;
import org.signal.libsignal.zkgroup.profiles.ClientZkProfileOperations;
@@ -41,7 +38,6 @@ import org.signal.storageservice.protos.groups.GroupResponse;
import org.signal.storageservice.protos.groups.Member;
import org.whispersystems.signalservice.api.account.AccountAttributes;
import org.whispersystems.signalservice.api.account.PreKeyCollection;
import org.whispersystems.signalservice.api.account.PreKeyUpload;
import org.whispersystems.signalservice.api.crypto.SealedSenderAccess;
import org.whispersystems.signalservice.api.groupsv2.GroupsV2AuthorizationString;
import org.whispersystems.signalservice.api.messages.SignalServiceAttachment.ProgressListener;
@@ -51,7 +47,6 @@ import org.whispersystems.signalservice.api.profiles.ProfileAndCredential;
import org.whispersystems.signalservice.api.profiles.SignalServiceProfile;
import org.whispersystems.signalservice.api.profiles.SignalServiceProfileWrite;
import org.whispersystems.signalservice.api.push.ServiceId.ACI;
import org.whispersystems.signalservice.api.push.ServiceIdType;
import org.whispersystems.signalservice.api.push.SignalServiceAddress;
import org.whispersystems.signalservice.api.push.SignedPreKeyEntity;
import org.whispersystems.signalservice.api.push.exceptions.AlreadyVerifiedException;
@@ -156,7 +151,6 @@ import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
@@ -191,11 +185,6 @@ public class PushServiceSocket {
private static final String DELETE_ACCOUNT_PATH = "/v1/accounts/me";
private static final String PREKEY_METADATA_PATH = "/v2/keys?identity=%s";
private static final String PREKEY_PATH = "/v2/keys?identity=%s";
private static final String PREKEY_DEVICE_PATH = "/v2/keys/%s/%s";
private static final String PREKEY_CHECK_PATH = "/v2/keys/check";
private static final String PROVISIONING_MESSAGE_PATH = "/v1/provisioning/%s";
private static final String SET_RESTORE_METHOD_PATH = "/v1/devices/restore_account/%s";
private static final String WAIT_RESTORE_METHOD_PATH = "/v1/devices/restore_account/%s?timeout=%s";
@@ -513,174 +502,6 @@ public class PushServiceSocket {
}
}
public void registerPreKeys(PreKeyUpload preKeyUpload)
throws IOException
{
SignedPreKeyEntity signedPreKey = null;
List<PreKeyEntity> oneTimeEcPreKeys = null;
KyberPreKeyEntity lastResortKyberPreKey = null;
List<KyberPreKeyEntity> oneTimeKyberPreKeys = null;
try {
if (preKeyUpload.getSignedPreKey() != null) {
signedPreKey = new SignedPreKeyEntity(preKeyUpload.getSignedPreKey().getId(),
preKeyUpload.getSignedPreKey().getKeyPair().getPublicKey(),
preKeyUpload.getSignedPreKey().getSignature());
}
} catch (InvalidKeyException e) {
throw new AssertionError("unexpected invalid key", e);
}
if (preKeyUpload.getOneTimeEcPreKeys() != null) {
oneTimeEcPreKeys = preKeyUpload
.getOneTimeEcPreKeys()
.stream()
.map(it -> {
try {
return new PreKeyEntity(it.getId(), it.getKeyPair().getPublicKey());
} catch (InvalidKeyException e) {
throw new AssertionError("unexpected invalid key", e);
}
})
.collect(Collectors.toList());
}
if (preKeyUpload.getLastResortKyberPreKey() != null) {
try {
lastResortKyberPreKey = new KyberPreKeyEntity(preKeyUpload.getLastResortKyberPreKey().getId(),
preKeyUpload.getLastResortKyberPreKey().getKeyPair().getPublicKey(),
preKeyUpload.getLastResortKyberPreKey().getSignature());
} catch (InvalidKeyException e) {
throw new AssertionError("unexpected invalid key", e);
}
}
if (preKeyUpload.getOneTimeKyberPreKeys() != null) {
oneTimeKyberPreKeys = preKeyUpload
.getOneTimeKyberPreKeys()
.stream()
.map(it -> {
try {
return new KyberPreKeyEntity(it.getId(), it.getKeyPair().getPublicKey(), it.getSignature());
} catch (InvalidKeyException e) {
throw new AssertionError("unexpected invalid key", e);
}
})
.collect(Collectors.toList());
}
makeServiceRequest(String.format(Locale.US, PREKEY_PATH, preKeyUpload.getServiceIdType().queryParam()),
"PUT",
JsonUtil.toJson(new PreKeyState(signedPreKey,
oneTimeEcPreKeys,
lastResortKyberPreKey,
oneTimeKyberPreKeys)));
}
public OneTimePreKeyCounts getAvailablePreKeys(ServiceIdType serviceIdType) throws IOException {
String path = String.format(PREKEY_METADATA_PATH, serviceIdType.queryParam());
String responseText = makeServiceRequest(path, "GET", null);
OneTimePreKeyCounts preKeyStatus = JsonUtil.fromJson(responseText, OneTimePreKeyCounts.class);
return preKeyStatus;
}
/**
* Retrieves prekeys. If the specified device is the primary (i.e. deviceId 1), it will retrieve prekeys
* for all devices. If it is not a primary, it will only contain the prekeys for that specific device.
*/
public List<PreKeyBundle> getPreKeys(SignalServiceAddress destination,
@Nullable SealedSenderAccess sealedSenderAccess,
int deviceId)
throws IOException
{
return getPreKeysBySpecifier(destination, sealedSenderAccess, deviceId == 1 ? "*" : String.valueOf(deviceId));
}
/**
* Retrieves a prekey for a specific device.
*/
public PreKeyBundle getPreKey(SignalServiceAddress destination, int deviceId) throws IOException {
List<PreKeyBundle> bundles = getPreKeysBySpecifier(destination, null, String.valueOf(deviceId));
if (bundles.size() > 0) {
return bundles.get(0);
} else {
throw new IOException("No prekeys available!");
}
}
private List<PreKeyBundle> getPreKeysBySpecifier(SignalServiceAddress destination,
@Nullable SealedSenderAccess sealedSenderAccess,
String deviceSpecifier)
throws IOException
{
try {
String path = String.format(PREKEY_DEVICE_PATH, destination.getIdentifier(), deviceSpecifier);
Log.d(TAG, "Fetching prekeys for " + destination.getIdentifier() + "." + deviceSpecifier + ", i.e. GET " + path);
String responseText = makeServiceRequest(path, "GET", null, NO_HEADERS, NO_HANDLER, sealedSenderAccess);
PreKeyResponse response = JsonUtil.fromJson(responseText, PreKeyResponse.class);
List<PreKeyBundle> bundles = new LinkedList<>();
for (PreKeyResponseItem device : response.getDevices()) {
ECPublicKey preKey = null;
ECPublicKey signedPreKey = null;
byte[] signedPreKeySignature = null;
int preKeyId = PreKeyBundle.NULL_PRE_KEY_ID;
int signedPreKeyId = PreKeyBundle.NULL_PRE_KEY_ID;
int kyberPreKeyId = PreKeyBundle.NULL_PRE_KEY_ID;
KEMPublicKey kyberPreKey = null;
byte[] kyberPreKeySignature = null;
if (device.getSignedPreKey() != null) {
signedPreKey = device.getSignedPreKey().getPublicKey();
signedPreKeyId = device.getSignedPreKey().getKeyId();
signedPreKeySignature = device.getSignedPreKey().getSignature();
}
if (device.getPreKey() != null) {
preKeyId = device.getPreKey().getKeyId();
preKey = device.getPreKey().getPublicKey();
}
if (device.getKyberPreKey() != null) {
kyberPreKey = device.getKyberPreKey().getPublicKey();
kyberPreKeyId = device.getKyberPreKey().getKeyId();
kyberPreKeySignature = device.getKyberPreKey().getSignature();
}
bundles.add(new PreKeyBundle(device.getRegistrationId(),
device.getDeviceId(),
preKeyId,
preKey,
signedPreKeyId,
signedPreKey,
signedPreKeySignature,
response.getIdentityKey(),
kyberPreKeyId,
kyberPreKey,
kyberPreKeySignature));
}
return bundles;
} catch (NotFoundException nfe) {
throw new UnregisteredUserException(destination.getIdentifier(), nfe);
}
}
public void checkRepeatedUsePreKeys(ServiceIdType serviceIdType, byte[] digest) throws IOException {
String body = JsonUtil.toJson(new CheckRepeatedUsedPreKeysRequest(serviceIdType.toString(), digest));
makeServiceRequest(PREKEY_CHECK_PATH, "POST", body, NO_HEADERS, (responseCode, errorBody, getHeader) -> {
// Must override this handling because otherwise code assumes a device mismatch error
if (responseCode == 409) {
throw new NonSuccessfulResponseCodeException(409);
}
}, null);
}
public void retrieveBackup(int cdnNumber, Map<String, String> headers, String cdnPath, File destination, long maxSizeBytes, ProgressListener listener)
throws MissingConfigurationException, IOException
{