mirror of
https://github.com/signalapp/Signal-Android.git
synced 2026-04-26 19:56:02 +01:00
Update registration for new restore flows.
This commit is contained in:
committed by
Greyson Parrelli
parent
aad2624bd5
commit
22c4e2d084
@@ -0,0 +1,67 @@
|
||||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.signalservice.api
|
||||
|
||||
import okhttp3.ConnectionSpec
|
||||
import okhttp3.OkHttpClient
|
||||
import org.whispersystems.signalservice.api.push.TrustStore
|
||||
import org.whispersystems.signalservice.api.util.Tls12SocketFactory
|
||||
import org.whispersystems.signalservice.api.util.TlsProxySocketFactory
|
||||
import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration
|
||||
import org.whispersystems.signalservice.internal.configuration.SignalUrl
|
||||
import org.whispersystems.signalservice.internal.util.BlacklistingTrustManager
|
||||
import org.whispersystems.signalservice.internal.util.Util
|
||||
import java.security.KeyManagementException
|
||||
import java.security.NoSuchAlgorithmException
|
||||
import java.util.concurrent.TimeUnit
|
||||
import javax.net.ssl.SSLContext
|
||||
import javax.net.ssl.SSLSocketFactory
|
||||
import javax.net.ssl.X509TrustManager
|
||||
|
||||
/**
|
||||
* Select a a URL at random to use.
|
||||
*/
|
||||
fun <T : SignalUrl> Array<T>.chooseUrl(): T {
|
||||
return this[(Math.random() * size).toInt()]
|
||||
}
|
||||
|
||||
/**
|
||||
* Build and configure an [OkHttpClient] as defined by the target [SignalUrl] and provided [configuration].
|
||||
*/
|
||||
fun <T : SignalUrl> T.buildOkHttpClient(configuration: SignalServiceConfiguration): OkHttpClient {
|
||||
val (socketFactory, trustManager) = createTlsSocketFactory(this.trustStore)
|
||||
|
||||
val builder = OkHttpClient.Builder()
|
||||
.sslSocketFactory(socketFactory, trustManager)
|
||||
.connectionSpecs(this.connectionSpecs.orElse(Util.immutableList(ConnectionSpec.RESTRICTED_TLS)))
|
||||
.retryOnConnectionFailure(false)
|
||||
.readTimeout(30, TimeUnit.SECONDS)
|
||||
.connectTimeout(30, TimeUnit.SECONDS)
|
||||
|
||||
for (interceptor in configuration.networkInterceptors) {
|
||||
builder.addInterceptor(interceptor)
|
||||
}
|
||||
|
||||
if (configuration.signalProxy.isPresent) {
|
||||
val proxy = configuration.signalProxy.get()
|
||||
builder.socketFactory(TlsProxySocketFactory(proxy.host, proxy.port, configuration.dns))
|
||||
}
|
||||
|
||||
return builder.build()
|
||||
}
|
||||
|
||||
private fun createTlsSocketFactory(trustStore: TrustStore): Pair<SSLSocketFactory, X509TrustManager> {
|
||||
return try {
|
||||
val context = SSLContext.getInstance("TLS")
|
||||
val trustManagers = BlacklistingTrustManager.createFor(trustStore)
|
||||
context.init(null, trustManagers, null)
|
||||
Tls12SocketFactory(context.socketFactory) to trustManagers[0] as X509TrustManager
|
||||
} catch (e: NoSuchAlgorithmException) {
|
||||
throw AssertionError(e)
|
||||
} catch (e: KeyManagementException) {
|
||||
throw AssertionError(e)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,270 @@
|
||||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.signalservice.api.registration
|
||||
|
||||
import kotlinx.coroutines.CancellationException
|
||||
import kotlinx.coroutines.CompletableDeferred
|
||||
import kotlinx.coroutines.CoroutineExceptionHandler
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.Job
|
||||
import kotlinx.coroutines.SupervisorJob
|
||||
import kotlinx.coroutines.cancel
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.plus
|
||||
import okhttp3.Request
|
||||
import okhttp3.Response
|
||||
import okhttp3.WebSocket
|
||||
import okhttp3.WebSocketListener
|
||||
import okio.ByteString
|
||||
import okio.ByteString.Companion.toByteString
|
||||
import org.signal.core.util.Base64
|
||||
import org.signal.core.util.logging.Log
|
||||
import org.signal.libsignal.protocol.IdentityKeyPair
|
||||
import org.signal.registration.proto.RegistrationProvisionEnvelope
|
||||
import org.whispersystems.signalservice.api.buildOkHttpClient
|
||||
import org.whispersystems.signalservice.api.chooseUrl
|
||||
import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration
|
||||
import org.whispersystems.signalservice.internal.crypto.SecondaryProvisioningCipher
|
||||
import org.whispersystems.signalservice.internal.push.ProvisioningAddress
|
||||
import org.whispersystems.signalservice.internal.websocket.WebSocketMessage
|
||||
import org.whispersystems.signalservice.internal.websocket.WebSocketRequestMessage
|
||||
import org.whispersystems.signalservice.internal.websocket.WebSocketResponseMessage
|
||||
import java.io.Closeable
|
||||
import java.io.IOException
|
||||
import java.net.SocketTimeoutException
|
||||
import java.net.URLEncoder
|
||||
import kotlin.time.Duration.Companion.seconds
|
||||
|
||||
/**
|
||||
* A provisional web socket for communicating with a primary device during registration.
|
||||
*/
|
||||
class ProvisioningSocket private constructor(
|
||||
identityKeyPair: IdentityKeyPair,
|
||||
configuration: SignalServiceConfiguration,
|
||||
private val scope: CoroutineScope
|
||||
) {
|
||||
companion object {
|
||||
private val TAG = Log.tag(ProvisioningSocket::class)
|
||||
|
||||
fun start(
|
||||
identityKeyPair: IdentityKeyPair,
|
||||
configuration: SignalServiceConfiguration,
|
||||
handler: CoroutineExceptionHandler,
|
||||
block: suspend CoroutineScope.(ProvisioningSocket) -> Unit
|
||||
): Closeable {
|
||||
val scope = CoroutineScope(Dispatchers.IO) + SupervisorJob() + handler
|
||||
|
||||
scope.launch {
|
||||
var socket: ProvisioningSocket? = null
|
||||
try {
|
||||
socket = ProvisioningSocket(identityKeyPair, configuration, scope)
|
||||
socket.connect()
|
||||
block(socket)
|
||||
} catch (e: CancellationException) {
|
||||
val rootCause = e.getRootCause()
|
||||
if (rootCause == null) {
|
||||
Log.i(TAG, "Scope canceled expectedly, fail silently, ${e.toMinimalString()}")
|
||||
throw e
|
||||
} else {
|
||||
Log.w(TAG, "Unable to maintain web socket, ${rootCause.toMinimalString()}", rootCause)
|
||||
throw rootCause
|
||||
}
|
||||
} finally {
|
||||
Log.d(TAG, "Closing web socket")
|
||||
socket?.close()
|
||||
}
|
||||
}
|
||||
|
||||
return Closeable { scope.cancel("scope closed") }
|
||||
}
|
||||
|
||||
/**
|
||||
* Get non-cancellation exception cause to determine if something legitimately failed.
|
||||
*/
|
||||
private fun CancellationException.getRootCause(): Throwable? {
|
||||
var cause: Throwable? = cause
|
||||
while (cause != null && cause is CancellationException) {
|
||||
cause = cause.cause
|
||||
}
|
||||
return cause
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a minimal throwable informational string since stack traces aren't always logged.
|
||||
*/
|
||||
private fun Throwable.toMinimalString(): String {
|
||||
return "${javaClass.simpleName}[$message]"
|
||||
}
|
||||
}
|
||||
|
||||
private val serviceUrl = configuration.signalServiceUrls.chooseUrl()
|
||||
private val okhttp = serviceUrl.buildOkHttpClient(configuration)
|
||||
|
||||
private val cipher = SecondaryProvisioningCipher(identityKeyPair)
|
||||
private var webSocket: WebSocket? = null
|
||||
|
||||
private val provisioningUrlDeferral: CompletableDeferred<String> = CompletableDeferred()
|
||||
private val provisioningMessageDeferral: CompletableDeferred<SecondaryProvisioningCipher.RegistrationProvisionResult> = CompletableDeferred()
|
||||
|
||||
suspend fun getProvisioningUrl(): String {
|
||||
return provisioningUrlDeferral.await()
|
||||
}
|
||||
|
||||
suspend fun getRegistrationProvisioningMessage(): SecondaryProvisioningCipher.RegistrationProvisionResult {
|
||||
return provisioningMessageDeferral.await()
|
||||
}
|
||||
|
||||
private fun connect() {
|
||||
val uri = serviceUrl.url.replace("https://", "wss://").replace("http://", "ws://")
|
||||
|
||||
val openRequest = Request.Builder()
|
||||
.url("$uri/v1/websocket/provisioning/")
|
||||
|
||||
if (serviceUrl.hostHeader.isPresent) {
|
||||
openRequest.addHeader("Host", serviceUrl.hostHeader.get())
|
||||
Log.w(TAG, "Using alternate host: ${serviceUrl.hostHeader.get()}")
|
||||
}
|
||||
|
||||
webSocket = okhttp.newWebSocket(openRequest.build(), ProvisioningWebSocketListener())
|
||||
}
|
||||
|
||||
private fun close() {
|
||||
webSocket?.close(1000, "Manual shutdown")
|
||||
}
|
||||
|
||||
private inner class ProvisioningWebSocketListener : WebSocketListener() {
|
||||
private var keepAliveJob: Job? = null
|
||||
|
||||
@Volatile
|
||||
private var lastKeepAliveId: Long = 0
|
||||
|
||||
override fun onOpen(webSocket: WebSocket, response: Response) {
|
||||
Log.d(TAG, "[onOpen]")
|
||||
keepAliveJob = scope.launch { keepAlive(webSocket) }
|
||||
|
||||
val timeoutJob = scope.launch {
|
||||
delay(10.seconds)
|
||||
scope.cancel("Did not receive device id within 10 seconds", SocketTimeoutException("No device id received"))
|
||||
}
|
||||
|
||||
scope.launch {
|
||||
provisioningUrlDeferral.await()
|
||||
timeoutJob.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
override fun onMessage(webSocket: WebSocket, bytes: ByteString) {
|
||||
val message: WebSocketMessage = WebSocketMessage.ADAPTER.decode(bytes)
|
||||
|
||||
if (message.response != null && message.response.id == lastKeepAliveId) {
|
||||
Log.d(TAG, "[onMessage] Keep alive received")
|
||||
return
|
||||
}
|
||||
|
||||
if (message.request == null) {
|
||||
Log.w(TAG, "[onMessage] Received null request")
|
||||
return
|
||||
}
|
||||
|
||||
val success = webSocket.send(message.request.toResponse().encode().toByteString())
|
||||
|
||||
if (!success) {
|
||||
Log.w(TAG, "[onMessage] Failed to send response")
|
||||
webSocket.close(1000, "OK")
|
||||
return
|
||||
}
|
||||
|
||||
Log.d(TAG, "[onMessage] Processing request")
|
||||
|
||||
if (message.request.verb == "PUT" && message.request.body != null) {
|
||||
when (message.request.path) {
|
||||
"/v1/address" -> {
|
||||
val address = ProvisioningAddress.ADAPTER.decode(message.request.body).address
|
||||
if (address != null) {
|
||||
provisioningUrlDeferral.complete(generateProvisioningUrl(address))
|
||||
} else {
|
||||
throw IOException("Device address is null")
|
||||
}
|
||||
}
|
||||
|
||||
"/v1/message" -> {
|
||||
val result = cipher.decrypt(RegistrationProvisionEnvelope.ADAPTER.decode(message.request.body))
|
||||
provisioningMessageDeferral.complete(result)
|
||||
}
|
||||
|
||||
else -> Log.w(TAG, "Unknown path requested")
|
||||
}
|
||||
} else {
|
||||
Log.w(TAG, "Invalid data")
|
||||
}
|
||||
}
|
||||
|
||||
override fun onClosing(webSocket: WebSocket, code: Int, reason: String) {
|
||||
scope.launch {
|
||||
Log.i(TAG, "[onClosing] code: $code reason: $reason")
|
||||
|
||||
if (code != 1000) {
|
||||
Log.w(TAG, "Remote side is closing with non-normal code $code")
|
||||
webSocket.close(1000, "Remote closed with code $code")
|
||||
}
|
||||
|
||||
scope.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) {
|
||||
scope.launch {
|
||||
Log.w(TAG, "[onFailure] Failed", t)
|
||||
webSocket.close(1000, "Failed ${t.message}")
|
||||
|
||||
scope.cancel(CancellationException("WebSocket Failure", t))
|
||||
}
|
||||
}
|
||||
|
||||
private fun generateProvisioningUrl(deviceAddress: String): String {
|
||||
val encodedDeviceId = URLEncoder.encode(deviceAddress, "UTF-8")
|
||||
val encodedPubKey: String = URLEncoder.encode(Base64.encodeWithoutPadding(cipher.secondaryDevicePublicKey.serialize()), "UTF-8")
|
||||
return "sgnl://rereg?uuid=$encodedDeviceId&pub_key=$encodedPubKey"
|
||||
}
|
||||
|
||||
private suspend fun keepAlive(webSocket: WebSocket) {
|
||||
Log.i(TAG, "[keepAlive] Starting")
|
||||
while (true) {
|
||||
delay(30.seconds)
|
||||
Log.i(TAG, "[keepAlive] Sending...")
|
||||
|
||||
val id = System.currentTimeMillis()
|
||||
val message = WebSocketMessage(
|
||||
type = WebSocketMessage.Type.REQUEST,
|
||||
request = WebSocketRequestMessage(
|
||||
id = id,
|
||||
path = "/v1/keepalive",
|
||||
verb = "GET"
|
||||
)
|
||||
)
|
||||
|
||||
if (!webSocket.send(message.encodeByteString())) {
|
||||
Log.w(TAG, "[keepAlive] Send failed")
|
||||
} else {
|
||||
lastKeepAliveId = id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun WebSocketRequestMessage.toResponse(): WebSocketMessage {
|
||||
return WebSocketMessage(
|
||||
type = WebSocketMessage.Type.RESPONSE,
|
||||
response = WebSocketResponseMessage(
|
||||
id = id,
|
||||
status = 200,
|
||||
message = "OK"
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -5,11 +5,14 @@
|
||||
|
||||
package org.whispersystems.signalservice.api.registration
|
||||
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey
|
||||
import org.signal.registration.proto.RegistrationProvisionMessage
|
||||
import org.whispersystems.signalservice.api.NetworkResult
|
||||
import org.whispersystems.signalservice.api.account.AccountAttributes
|
||||
import org.whispersystems.signalservice.api.account.ChangePhoneNumberRequest
|
||||
import org.whispersystems.signalservice.api.account.PniKeyDistributionRequest
|
||||
import org.whispersystems.signalservice.api.account.PreKeyCollection
|
||||
import org.whispersystems.signalservice.internal.crypto.PrimaryProvisioningCipher
|
||||
import org.whispersystems.signalservice.internal.push.BackupV2AuthCheckResponse
|
||||
import org.whispersystems.signalservice.internal.push.BackupV3AuthCheckResponse
|
||||
import org.whispersystems.signalservice.internal.push.PushServiceSocket
|
||||
@@ -142,4 +145,20 @@ class RegistrationApi(
|
||||
pushServiceSocket.distributePniKeys(requestBody)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Encrypts and sends the [RegistrationProvisionMessage] from the current primary (old device) to the new device over
|
||||
* the provisioning web socket identified by [deviceIdentifier].
|
||||
*/
|
||||
fun sendReRegisterDeviceProvisioningMessage(
|
||||
deviceIdentifier: String,
|
||||
deviceKey: ECPublicKey,
|
||||
registrationProvisionMessage: RegistrationProvisionMessage
|
||||
): NetworkResult<Unit> {
|
||||
val cipherText = PrimaryProvisioningCipher(deviceKey).encrypt(registrationProvisionMessage)
|
||||
|
||||
return NetworkResult.fromFetch {
|
||||
pushServiceSocket.sendProvisioningMessage(deviceIdentifier, cipherText)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package org.whispersystems.signalservice.api.svr
|
||||
|
||||
import okhttp3.ConnectionSpec
|
||||
import okhttp3.OkHttpClient
|
||||
import okhttp3.Request
|
||||
import okhttp3.WebSocket
|
||||
@@ -9,30 +8,19 @@ import okio.ByteString
|
||||
import okio.ByteString.Companion.toByteString
|
||||
import org.signal.libsignal.attest.AttestationDataException
|
||||
import org.signal.libsignal.protocol.logging.Log
|
||||
import org.signal.libsignal.protocol.util.Pair
|
||||
import org.signal.libsignal.sgxsession.SgxCommunicationFailureException
|
||||
import org.signal.libsignal.svr2.Svr2Client
|
||||
import org.whispersystems.signalservice.api.push.TrustStore
|
||||
import org.whispersystems.signalservice.api.buildOkHttpClient
|
||||
import org.whispersystems.signalservice.api.chooseUrl
|
||||
import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException
|
||||
import org.whispersystems.signalservice.api.util.Tls12SocketFactory
|
||||
import org.whispersystems.signalservice.api.util.TlsProxySocketFactory
|
||||
import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration
|
||||
import org.whispersystems.signalservice.internal.configuration.SignalSvr2Url
|
||||
import org.whispersystems.signalservice.internal.push.AuthCredentials
|
||||
import org.whispersystems.signalservice.internal.util.BlacklistingTrustManager
|
||||
import org.whispersystems.signalservice.internal.util.Hex
|
||||
import org.whispersystems.signalservice.internal.util.Util
|
||||
import java.io.IOException
|
||||
import java.security.KeyManagementException
|
||||
import java.security.NoSuchAlgorithmException
|
||||
import java.time.Instant
|
||||
import java.util.concurrent.CountDownLatch
|
||||
import java.util.concurrent.TimeUnit
|
||||
import java.util.concurrent.atomic.AtomicReference
|
||||
import javax.net.ssl.SSLContext
|
||||
import javax.net.ssl.SSLSocketFactory
|
||||
import javax.net.ssl.X509TrustManager
|
||||
import kotlin.jvm.Throws
|
||||
import okhttp3.Response as OkHttpResponse
|
||||
import org.signal.svr2.proto.Request as Svr2Request
|
||||
import org.signal.svr2.proto.Response as Svr2Response
|
||||
@@ -44,8 +32,8 @@ internal class Svr2Socket(
|
||||
configuration: SignalServiceConfiguration,
|
||||
private val mrEnclave: String
|
||||
) {
|
||||
private val svr2Url: SignalSvr2Url = chooseUrl(configuration.signalSvr2Urls)
|
||||
private val okhttp: OkHttpClient = buildOkHttpClient(configuration, svr2Url)
|
||||
private val svr2Url: SignalSvr2Url = configuration.signalSvr2Urls.chooseUrl()
|
||||
private val okhttp: OkHttpClient = svr2Url.buildOkHttpClient(configuration)
|
||||
|
||||
@Throws(IOException::class)
|
||||
fun makeRequest(authorization: AuthCredentials, clientRequest: Svr2Request): Svr2Response {
|
||||
@@ -212,43 +200,5 @@ internal class Svr2Socket(
|
||||
|
||||
companion object {
|
||||
private val TAG = Svr2Socket::class.java.simpleName
|
||||
|
||||
private fun buildOkHttpClient(configuration: SignalServiceConfiguration, svr2Url: SignalSvr2Url): OkHttpClient {
|
||||
val socketFactory = createTlsSocketFactory(svr2Url.trustStore)
|
||||
val builder = OkHttpClient.Builder()
|
||||
.sslSocketFactory(Tls12SocketFactory(socketFactory.first()), socketFactory.second())
|
||||
.connectionSpecs(svr2Url.connectionSpecs.orElse(Util.immutableList(ConnectionSpec.RESTRICTED_TLS)))
|
||||
.retryOnConnectionFailure(false)
|
||||
.readTimeout(30, TimeUnit.SECONDS)
|
||||
.connectTimeout(30, TimeUnit.SECONDS)
|
||||
|
||||
for (interceptor in configuration.networkInterceptors) {
|
||||
builder.addInterceptor(interceptor)
|
||||
}
|
||||
|
||||
if (configuration.signalProxy.isPresent) {
|
||||
val proxy = configuration.signalProxy.get()
|
||||
builder.socketFactory(TlsProxySocketFactory(proxy.host, proxy.port, configuration.dns))
|
||||
}
|
||||
|
||||
return builder.build()
|
||||
}
|
||||
|
||||
private fun createTlsSocketFactory(trustStore: TrustStore): Pair<SSLSocketFactory, X509TrustManager> {
|
||||
return try {
|
||||
val context = SSLContext.getInstance("TLS")
|
||||
val trustManagers = BlacklistingTrustManager.createFor(trustStore)
|
||||
context.init(null, trustManagers, null)
|
||||
Pair(context.socketFactory, trustManagers[0] as X509TrustManager)
|
||||
} catch (e: NoSuchAlgorithmException) {
|
||||
throw AssertionError(e)
|
||||
} catch (e: KeyManagementException) {
|
||||
throw AssertionError(e)
|
||||
}
|
||||
}
|
||||
|
||||
private fun chooseUrl(urls: Array<SignalSvr2Url>): SignalSvr2Url {
|
||||
return urls[(Math.random() * urls.size).toInt()]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,8 @@ import org.signal.libsignal.protocol.ecc.Curve;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
import org.signal.libsignal.protocol.kdf.HKDF;
|
||||
import org.signal.registration.proto.RegistrationProvisionEnvelope;
|
||||
import org.signal.registration.proto.RegistrationProvisionMessage;
|
||||
import org.whispersystems.signalservice.internal.push.ProvisionEnvelope;
|
||||
import org.whispersystems.signalservice.internal.push.ProvisionMessage;
|
||||
import org.whispersystems.signalservice.internal.util.Util;
|
||||
@@ -54,6 +56,24 @@ public class PrimaryProvisioningCipher {
|
||||
.encode();
|
||||
}
|
||||
|
||||
public byte[] encrypt(RegistrationProvisionMessage message) throws InvalidKeyException {
|
||||
ECKeyPair ourKeyPair = Curve.generateKeyPair();
|
||||
byte[] sharedSecret = Curve.calculateAgreement(theirPublicKey, ourKeyPair.getPrivateKey());
|
||||
byte[] derivedSecret = HKDF.deriveSecrets(sharedSecret, PROVISIONING_MESSAGE.getBytes(), 64);
|
||||
byte[][] parts = Util.split(derivedSecret, 32, 32);
|
||||
|
||||
byte[] version = { 0x00 };
|
||||
byte[] ciphertext = getCiphertext(parts[0], message.encode());
|
||||
byte[] mac = getMac(parts[1], Util.join(version, ciphertext));
|
||||
byte[] body = Util.join(version, ciphertext, mac);
|
||||
|
||||
return new RegistrationProvisionEnvelope.Builder()
|
||||
.publicKey(ByteString.of(ourKeyPair.getPublicKey().serialize()))
|
||||
.body(ByteString.of(body))
|
||||
.build()
|
||||
.encode();
|
||||
}
|
||||
|
||||
private byte[] getCiphertext(byte[] key, byte[] message) {
|
||||
try {
|
||||
Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
|
||||
|
||||
@@ -0,0 +1,160 @@
|
||||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.signalservice.internal.crypto
|
||||
|
||||
import org.signal.core.util.logging.Log
|
||||
import org.signal.libsignal.protocol.IdentityKey
|
||||
import org.signal.libsignal.protocol.IdentityKeyPair
|
||||
import org.signal.libsignal.protocol.ecc.Curve
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey
|
||||
import org.signal.libsignal.protocol.kdf.HKDF
|
||||
import org.signal.libsignal.zkgroup.profiles.ProfileKey
|
||||
import org.signal.registration.proto.RegistrationProvisionEnvelope
|
||||
import org.signal.registration.proto.RegistrationProvisionMessage
|
||||
import org.whispersystems.signalservice.api.util.UuidUtil
|
||||
import org.whispersystems.signalservice.internal.push.ProvisionEnvelope
|
||||
import org.whispersystems.signalservice.internal.push.ProvisionMessage
|
||||
import java.security.InvalidKeyException
|
||||
import java.security.MessageDigest
|
||||
import java.security.NoSuchAlgorithmException
|
||||
import java.util.UUID
|
||||
import javax.crypto.Cipher
|
||||
import javax.crypto.Mac
|
||||
import javax.crypto.spec.IvParameterSpec
|
||||
import javax.crypto.spec.SecretKeySpec
|
||||
|
||||
/**
|
||||
* Used to decrypt a secondary/link device provisioning message from the primary device.
|
||||
*/
|
||||
class SecondaryProvisioningCipher(private val secondaryIdentityKeyPair: IdentityKeyPair) {
|
||||
|
||||
companion object {
|
||||
private val TAG = Log.tag(SecondaryProvisioningCipher::class)
|
||||
|
||||
private const val VERSION_LENGTH = 1
|
||||
private const val IV_LENGTH = 16
|
||||
private const val MAC_LENGTH = 32
|
||||
|
||||
fun generate(identityKeyPair: IdentityKeyPair): SecondaryProvisioningCipher {
|
||||
return SecondaryProvisioningCipher(identityKeyPair)
|
||||
}
|
||||
}
|
||||
|
||||
val secondaryDevicePublicKey: IdentityKey = secondaryIdentityKeyPair.publicKey
|
||||
|
||||
fun decrypt(envelope: ProvisionEnvelope): ProvisionDecryptResult {
|
||||
val plaintext = decrypt(expectedVersion = 1, primaryEphemeralPublicKey = envelope.publicKey!!.toByteArray(), body = envelope.body!!.toByteArray())
|
||||
|
||||
if (plaintext == null) {
|
||||
Log.w(TAG, "Plaintext is null")
|
||||
return ProvisionDecryptResult.Error
|
||||
}
|
||||
|
||||
val provisioningMessage = ProvisionMessage.ADAPTER.decode(plaintext)
|
||||
|
||||
return ProvisionDecryptResult.Success(
|
||||
uuid = UuidUtil.parseOrThrow(provisioningMessage.aci),
|
||||
e164 = provisioningMessage.number!!,
|
||||
identityKeyPair = IdentityKeyPair(IdentityKey(provisioningMessage.aciIdentityKeyPublic!!.toByteArray()), Curve.decodePrivatePoint(provisioningMessage.aciIdentityKeyPrivate!!.toByteArray())),
|
||||
profileKey = ProfileKey(provisioningMessage.profileKey!!.toByteArray()),
|
||||
areReadReceiptsEnabled = provisioningMessage.readReceipts == true,
|
||||
primaryUserAgent = provisioningMessage.userAgent,
|
||||
provisioningCode = provisioningMessage.provisioningCode!!,
|
||||
provisioningVersion = provisioningMessage.provisioningVersion!!
|
||||
)
|
||||
}
|
||||
|
||||
fun decrypt(envelope: RegistrationProvisionEnvelope): RegistrationProvisionResult {
|
||||
val plaintext = decrypt(expectedVersion = 0, primaryEphemeralPublicKey = envelope.publicKey.toByteArray(), body = envelope.body.toByteArray())
|
||||
|
||||
if (plaintext == null) {
|
||||
Log.w(TAG, "Plaintext is null")
|
||||
return RegistrationProvisionResult.Error
|
||||
}
|
||||
|
||||
val provisioningMessage = RegistrationProvisionMessage.ADAPTER.decode(plaintext)
|
||||
|
||||
return RegistrationProvisionResult.Success(provisioningMessage)
|
||||
}
|
||||
|
||||
private fun decrypt(expectedVersion: Int, primaryEphemeralPublicKey: ByteArray, body: ByteArray): ByteArray? {
|
||||
val provisionMessageLength = body.size - VERSION_LENGTH - IV_LENGTH - MAC_LENGTH
|
||||
|
||||
if (provisionMessageLength <= 0) {
|
||||
Log.w(TAG, "Provisioning message length invalid")
|
||||
return null
|
||||
}
|
||||
|
||||
val version = body[0].toInt()
|
||||
if (version != expectedVersion) {
|
||||
Log.w(TAG, "Version does not match expected, expected $expectedVersion but was $version")
|
||||
return null
|
||||
}
|
||||
|
||||
val iv = body.sliceArray(1 until (1 + IV_LENGTH))
|
||||
val theirMac = body.sliceArray(body.size - MAC_LENGTH until body.size)
|
||||
val message = body.sliceArray(0 until body.size - MAC_LENGTH)
|
||||
val cipherText = body.sliceArray((1 + IV_LENGTH) until body.size - MAC_LENGTH)
|
||||
|
||||
val sharedSecret = Curve.calculateAgreement(ECPublicKey(primaryEphemeralPublicKey), secondaryIdentityKeyPair.privateKey)
|
||||
val derivedSecret: ByteArray = HKDF.deriveSecrets(sharedSecret, PrimaryProvisioningCipher.PROVISIONING_MESSAGE.toByteArray(), 64)
|
||||
|
||||
val cipherKey = derivedSecret.sliceArray(0 until 32)
|
||||
val macKey = derivedSecret.sliceArray(32 until 64)
|
||||
|
||||
val ourHmac = getMac(macKey, message)
|
||||
|
||||
if (!MessageDigest.isEqual(theirMac, ourHmac)) {
|
||||
Log.w(TAG, "Macs do not match")
|
||||
return null
|
||||
}
|
||||
|
||||
return try {
|
||||
getPlaintext(cipherKey, iv, cipherText)
|
||||
} catch (e: Exception) {
|
||||
Log.w(TAG, "Unable to get plaintext", e)
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
private fun getMac(key: ByteArray, message: ByteArray): ByteArray? {
|
||||
return try {
|
||||
val mac = Mac.getInstance("HmacSHA256")
|
||||
mac.init(SecretKeySpec(key, "HmacSHA256"))
|
||||
mac.doFinal(message)
|
||||
} catch (e: NoSuchAlgorithmException) {
|
||||
throw AssertionError(e)
|
||||
} catch (e: InvalidKeyException) {
|
||||
throw AssertionError(e)
|
||||
}
|
||||
}
|
||||
|
||||
private fun getPlaintext(key: ByteArray, iv: ByteArray, message: ByteArray): ByteArray {
|
||||
val cipher = Cipher.getInstance("AES/CBC/PKCS5Padding")
|
||||
cipher.init(Cipher.DECRYPT_MODE, SecretKeySpec(key, "AES"), IvParameterSpec(iv))
|
||||
return cipher.doFinal(message)
|
||||
}
|
||||
|
||||
sealed interface ProvisionDecryptResult {
|
||||
data object Error : ProvisionDecryptResult
|
||||
|
||||
data class Success(
|
||||
val uuid: UUID,
|
||||
val e164: String,
|
||||
val identityKeyPair: IdentityKeyPair,
|
||||
val profileKey: ProfileKey,
|
||||
val areReadReceiptsEnabled: Boolean,
|
||||
val primaryUserAgent: String?,
|
||||
val provisioningCode: String,
|
||||
val provisioningVersion: Int
|
||||
) : ProvisionDecryptResult
|
||||
}
|
||||
|
||||
sealed interface RegistrationProvisionResult {
|
||||
data object Error : RegistrationProvisionResult
|
||||
data class Success(val message: RegistrationProvisionMessage) : RegistrationProvisionResult
|
||||
}
|
||||
}
|
||||
@@ -10,8 +10,8 @@ package signalservice;
|
||||
option java_package = "org.whispersystems.signalservice.internal.push";
|
||||
option java_outer_classname = "ProvisioningProtos";
|
||||
|
||||
message ProvisioningUuid {
|
||||
optional string uuid = 1;
|
||||
message ProvisioningAddress {
|
||||
optional string address = 1;
|
||||
}
|
||||
|
||||
message ProvisionEnvelope {
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
syntax = "proto3";
|
||||
|
||||
option java_multiple_files = true;
|
||||
option java_package = "org.signal.registration.proto";
|
||||
|
||||
message RegistrationProvisionEnvelope {
|
||||
bytes publicKey = 1;
|
||||
bytes body = 2; // Encrypted RegistrationProvisionMessage
|
||||
}
|
||||
|
||||
message RegistrationProvisionMessage {
|
||||
enum Platform {
|
||||
ANDROID = 0;
|
||||
IOS = 1;
|
||||
}
|
||||
|
||||
enum Tier {
|
||||
FREE = 0;
|
||||
PAID = 1;
|
||||
}
|
||||
|
||||
string e164 = 1;
|
||||
bytes aci = 2;
|
||||
string accountEntropyPool = 3;
|
||||
string pin = 4;
|
||||
Platform platform = 5;
|
||||
uint64 backupTimestampMs = 6;
|
||||
Tier tier = 7;
|
||||
reserved 8; // iOSDeviceTransferMessage
|
||||
}
|
||||
Reference in New Issue
Block a user