Use LibSignalChatConnection for Authenticated Socket based on Remote Config

This commit is contained in:
andrew-signal
2024-12-05 09:27:31 -05:00
committed by Greyson Parrelli
parent 9389f373c6
commit 080b79c893
6 changed files with 278 additions and 44 deletions

View File

@@ -408,6 +408,16 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider {
return new WebSocketFactory() { return new WebSocketFactory() {
@Override @Override
public WebSocketConnection createWebSocket() { public WebSocketConnection createWebSocket() {
if (RemoteConfig.libSignalWebSocketEnabled()) {
Network network = libSignalNetworkSupplier.get();
return new LibSignalChatConnection(
"libsignal-auth",
network,
new DynamicCredentialsProvider(),
Stories.isFeatureEnabled(),
healthMonitor
);
} else {
return new OkHttpWebSocketConnection("normal", return new OkHttpWebSocketConnection("normal",
signalServiceConfigurationSupplier.get(), signalServiceConfigurationSupplier.get(),
Optional.of(new DynamicCredentialsProvider()), Optional.of(new DynamicCredentialsProvider()),
@@ -415,6 +425,7 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider {
healthMonitor, healthMonitor,
Stories.isFeatureEnabled()); Stories.isFeatureEnabled());
} }
}
@Override @Override
public WebSocketConnection createUnidentifiedWebSocket() { public WebSocketConnection createUnidentifiedWebSocket() {

View File

@@ -1,10 +0,0 @@
package org.whispersystems.signalservice.api.websocket;
/**
* Callbacks to provide WebSocket health information to a monitor.
*/
public interface HealthMonitor {
void onKeepAliveResponse(long sentTimestamp, boolean isIdentifiedWebSocket);
void onMessageError(int status, boolean isIdentifiedWebSocket);
}

View File

@@ -0,0 +1,10 @@
package org.whispersystems.signalservice.api.websocket
/**
* Callbacks to provide WebSocket health information to a monitor.
*/
interface HealthMonitor {
fun onKeepAliveResponse(sentTimestamp: Long, isIdentifiedWebSocket: Boolean)
fun onMessageError(status: Int, isIdentifiedWebSocket: Boolean)
}

View File

@@ -10,11 +10,14 @@ import io.reactivex.rxjava3.core.Single
import io.reactivex.rxjava3.schedulers.Schedulers import io.reactivex.rxjava3.schedulers.Schedulers
import io.reactivex.rxjava3.subjects.BehaviorSubject import io.reactivex.rxjava3.subjects.BehaviorSubject
import io.reactivex.rxjava3.subjects.SingleSubject import io.reactivex.rxjava3.subjects.SingleSubject
import okio.ByteString
import okio.ByteString.Companion.toByteString
import org.signal.core.util.logging.Log import org.signal.core.util.logging.Log
import org.signal.libsignal.net.AuthenticatedChatService import org.signal.libsignal.net.AuthenticatedChatService
import org.signal.libsignal.net.ChatListener import org.signal.libsignal.net.ChatListener
import org.signal.libsignal.net.ChatService import org.signal.libsignal.net.ChatService
import org.signal.libsignal.net.ChatServiceException import org.signal.libsignal.net.ChatServiceException
import org.signal.libsignal.net.DeviceDeregisteredException
import org.signal.libsignal.net.Network import org.signal.libsignal.net.Network
import org.signal.libsignal.net.UnauthenticatedChatService import org.signal.libsignal.net.UnauthenticatedChatService
import org.whispersystems.signalservice.api.util.CredentialsProvider import org.whispersystems.signalservice.api.util.CredentialsProvider
@@ -24,6 +27,11 @@ import org.whispersystems.signalservice.internal.util.whenComplete
import java.io.IOException import java.io.IOException
import java.time.Instant import java.time.Instant
import java.util.Optional import java.util.Optional
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeoutException
import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.locks.ReentrantLock import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock import kotlin.concurrent.withLock
import kotlin.time.Duration.Companion.seconds import kotlin.time.Duration.Companion.seconds
@@ -50,11 +58,33 @@ class LibSignalChatConnection(
private val receiveStories: Boolean, private val receiveStories: Boolean,
private val healthMonitor: HealthMonitor private val healthMonitor: HealthMonitor
) : WebSocketConnection { ) : WebSocketConnection {
private val incomingRequestQueue = LinkedBlockingQueue<WebSocketRequestMessage>()
// One of the more nasty parts of this is that libsignal-net does not expose, nor does it ever
// intend to expose, the ID of the incoming "request" to the app layer. Instead, the app layer
// is given a callback for each message it should call when it wants to ack that message.
// The layer above this, SignalWebSocket.java, is written to handle HTTP Requests and Responses
// that have responses embedded within them.
// The goal of this stage of the project is to try and change as little as possible to isolate
// any bugs with underlying libsignal-net layer.
// So, we lie.
// We assign our own "pseudo IDs" for each incoming request in this layer, provide that ID
// up the stack to the SignalWebSocket, and then we store it. Eventually, SignalWebSocket will
// tell us to send a response for that ID, and then we use the pseudo ID as a handle to find
// the callback given to us earlier by libsignal-net, and we call that callback.
private val nextIncomingMessageInternalPseudoId = AtomicLong(1)
val ackSenderForInternalPseudoId = ConcurrentHashMap<Long, ChatListener.ServerMessageAck>()
private val CHAT_SERVICE_LOCK = ReentrantLock() private val CHAT_SERVICE_LOCK = ReentrantLock()
private var chatService: ChatService? = null private var chatService: ChatService? = null
companion object { companion object {
const val SERVICE_ENVELOPE_REQUEST_VERB = "PUT"
const val SERVICE_ENVELOPE_REQUEST_PATH = "/api/v1/message"
const val SOCKET_EMPTY_REQUEST_VERB = "PUT"
const val SOCKET_EMPTY_REQUEST_PATH = "/api/v1/queue/empty"
const val SIGNAL_SERVICE_ENVELOPE_TIMESTAMP_HEADER_KEY = "X-Signal-Timestamp"
private val TAG = Log.tag(LibSignalChatConnection::class.java) private val TAG = Log.tag(LibSignalChatConnection::class.java)
private val SEND_TIMEOUT: Long = 10.seconds.inWholeMilliseconds private val SEND_TIMEOUT: Long = 10.seconds.inWholeMilliseconds
@@ -96,6 +126,22 @@ class LibSignalChatConnection(
val state = BehaviorSubject.createDefault(WebSocketConnectionState.DISCONNECTED) val state = BehaviorSubject.createDefault(WebSocketConnectionState.DISCONNECTED)
val cleanupMonitor = state.subscribe { nextState ->
if (nextState == WebSocketConnectionState.DISCONNECTED) {
cleanup()
}
}
private fun cleanup() {
Log.i(TAG, "$name [cleanup]")
incomingRequestQueue.clear()
// There's a race condition here where someone has a request with an ack outstanding
// when we clear the ackSender table, but it's benign because we handle the case where
// there is no ackSender for a pseudoId gracefully in sendResponse.
ackSenderForInternalPseudoId.clear()
// There's no sense in resetting nextIncomingMessageInternalPseudoId.
}
override fun connect(): Observable<WebSocketConnectionState> { override fun connect(): Observable<WebSocketConnectionState> {
CHAT_SERVICE_LOCK.withLock { CHAT_SERVICE_LOCK.withLock {
if (chatService != null) { if (chatService != null) {
@@ -112,17 +158,28 @@ class LibSignalChatConnection(
state.onNext(WebSocketConnectionState.CONNECTED) state.onNext(WebSocketConnectionState.CONNECTED)
}, },
onFailure = { throwable -> onFailure = { throwable ->
// TODO[libsignal-net]: Report AUTHENTICATION_FAILED for 401 and 403 errors Log.w(TAG, "$name [connect] Failure:", throwable)
Log.w(TAG, "$name Connect failed", throwable) // Internally, libsignal-net will throw this DeviceDeregisteredException when the HTTP CONNECT
// request returns HTTP 403.
// The chat service currently does not return HTTP 401 on /v1/websocket.
// Thus, this currently matches the implementation in OkHttpWebSocketConnection.
if (throwable is DeviceDeregisteredException) {
state.onNext(WebSocketConnectionState.AUTHENTICATION_FAILED)
} else {
state.onNext(WebSocketConnectionState.FAILED) state.onNext(WebSocketConnectionState.FAILED)
} }
}
) )
} }
return state return state
} }
} }
override fun isDead(): Boolean = false override fun isDead(): Boolean {
CHAT_SERVICE_LOCK.withLock {
return chatService == null
}
}
override fun disconnect() { override fun disconnect() {
CHAT_SERVICE_LOCK.withLock { CHAT_SERVICE_LOCK.withLock {
@@ -150,16 +207,20 @@ class LibSignalChatConnection(
override fun sendRequest(request: WebSocketRequestMessage): Single<WebsocketResponse> { override fun sendRequest(request: WebSocketRequestMessage): Single<WebsocketResponse> {
CHAT_SERVICE_LOCK.withLock { CHAT_SERVICE_LOCK.withLock {
if (chatService == null) { if (chatService == null) {
return Single.error(IOException("[$name] is closed!")) return Single.error(IOException("$name is closed!"))
} }
val single = SingleSubject.create<WebsocketResponse>() val single = SingleSubject.create<WebsocketResponse>()
val internalRequest = request.toLibSignalRequest() val internalRequest = request.toLibSignalRequest()
chatService!!.send(internalRequest) chatService!!.send(internalRequest)
.whenComplete( .whenComplete(
onSuccess = { response -> onSuccess = { response ->
when (response!!.status) { Log.d(TAG, "$name [sendRequest] Success: ${response!!.status}")
when (response.status) {
in 400..599 -> { in 400..599 -> {
healthMonitor.onMessageError(response.status, false) healthMonitor.onMessageError(
status = response.status,
isIdentifiedWebSocket = chatService is AuthenticatedChatService
)
} }
} }
// Here success means "we received the response" even if it is reporting an error. // Here success means "we received the response" even if it is reporting an error.
@@ -167,7 +228,7 @@ class LibSignalChatConnection(
single.onSuccess(response.toWebsocketResponse(isUnidentified = (chatService is UnauthenticatedChatService))) single.onSuccess(response.toWebsocketResponse(isUnidentified = (chatService is UnauthenticatedChatService)))
}, },
onFailure = { throwable -> onFailure = { throwable ->
Log.w(TAG, "$name sendRequest failed", throwable) Log.w(TAG, "$name [sendRequest] Failure:", throwable)
single.onError(throwable) single.onError(throwable)
} }
) )
@@ -185,12 +246,12 @@ class LibSignalChatConnection(
chatService!!.sendAndDebug(KEEP_ALIVE_REQUEST) chatService!!.sendAndDebug(KEEP_ALIVE_REQUEST)
.whenComplete( .whenComplete(
onSuccess = { debugResponse -> onSuccess = { debugResponse ->
Log.d(TAG, "$name Keep alive - success") Log.d(TAG, "$name [sendKeepAlive] Success")
when (debugResponse!!.response.status) { when (debugResponse!!.response.status) {
in 200..299 -> { in 200..299 -> {
healthMonitor.onKeepAliveResponse( healthMonitor.onKeepAliveResponse(
Instant.now().toEpochMilli(), // ignored. can be any value sentTimestamp = Instant.now().toEpochMilli(), // ignored. can be any value
false isIdentifiedWebSocket = chatService is AuthenticatedChatService
) )
} }
@@ -199,12 +260,12 @@ class LibSignalChatConnection(
} }
else -> { else -> {
Log.w(TAG, "$name Unsupported keep alive response status: ${debugResponse.response.status}") Log.w(TAG, "$name [sendKeepAlive] Unsupported keep alive response status: ${debugResponse.response.status}")
} }
} }
}, },
onFailure = { throwable -> onFailure = { throwable ->
Log.w(TAG, "$name Keep alive - failed", throwable) Log.w(TAG, "$name [sendKeepAlive] Failure:", throwable)
state.onNext(WebSocketConnectionState.DISCONNECTED) state.onNext(WebSocketConnectionState.DISCONNECTED)
} }
) )
@@ -212,28 +273,79 @@ class LibSignalChatConnection(
} }
override fun readRequestIfAvailable(): Optional<WebSocketRequestMessage> { override fun readRequestIfAvailable(): Optional<WebSocketRequestMessage> {
throw NotImplementedError() val incomingMessage = incomingRequestQueue.poll()
return Optional.ofNullable(incomingMessage)
} }
override fun readRequest(timeoutMillis: Long): WebSocketRequestMessage { override fun readRequest(timeoutMillis: Long): WebSocketRequestMessage {
throw NotImplementedError() return readRequestInternal(timeoutMillis, timeoutMillis)
} }
override fun sendResponse(response: WebSocketResponseMessage?) { private fun readRequestInternal(timeoutMillis: Long, originalTimeoutMillis: Long): WebSocketRequestMessage {
throw NotImplementedError() if (timeoutMillis < 0) {
throw TimeoutException("No message available after $originalTimeoutMillis ms")
} }
private val listener = object : ChatListener { val startTime = System.currentTimeMillis()
override fun onIncomingMessage(chat: ChatService?, envelope: ByteArray?, serverDeliveryTimestamp: Long, sendAck: ChatListener.ServerMessageAck?) { try {
throw NotImplementedError() return incomingRequestQueue.poll(timeoutMillis, TimeUnit.MILLISECONDS) ?: throw TimeoutException("No message available after $originalTimeoutMillis ms")
} catch (e: InterruptedException) {
val elapsedTimeMillis = System.currentTimeMillis() - startTime
val timeoutRemainingMillis = timeoutMillis - elapsedTimeMillis
return readRequestInternal(timeoutRemainingMillis, originalTimeoutMillis)
}
} }
override fun onConnectionInterrupted(chat: ChatService?, disconnectReason: ChatServiceException?) { override fun sendResponse(response: WebSocketResponseMessage) {
if (response.status == 200 && response.message.equals("OK")) {
ackSenderForInternalPseudoId[response.id]?.send() ?: Log.w(TAG, "$name [sendResponse] Silently dropped response without available ackSend {id: ${response.id}}")
ackSenderForInternalPseudoId.remove(response.id)
Log.d(TAG, "$name [sendResponse] sent ack [${response.id}]")
} else {
// libsignal-net only supports sending {200: OK} responses
Log.w(TAG, "$name [sendResponse] Silently dropped unsupported response {status: ${response.status}, id: ${response.id}}")
ackSenderForInternalPseudoId.remove(response.id)
}
}
private val listener = LibSignalChatListener()
private inner class LibSignalChatListener : ChatListener {
override fun onIncomingMessage(chat: ChatService, envelope: ByteArray, serverDeliveryTimestamp: Long, sendAck: ChatListener.ServerMessageAck?) {
// NB: The order here is intentional to ensure concurrency-safety, so that when a request is pulled off the queue, its sendAck is
// already in the ackSender map, if it exists.
val internalPseudoId = nextIncomingMessageInternalPseudoId.getAndIncrement()
val incomingWebSocketRequest = WebSocketRequestMessage(
verb = SERVICE_ENVELOPE_REQUEST_VERB,
path = SERVICE_ENVELOPE_REQUEST_PATH,
body = envelope.toByteString(),
headers = listOf("$SIGNAL_SERVICE_ENVELOPE_TIMESTAMP_HEADER_KEY: $serverDeliveryTimestamp"),
id = internalPseudoId
)
if (sendAck != null) {
ackSenderForInternalPseudoId[internalPseudoId] = sendAck
}
incomingRequestQueue.put(incomingWebSocketRequest)
}
override fun onConnectionInterrupted(chat: ChatService, disconnectReason: ChatServiceException) {
CHAT_SERVICE_LOCK.withLock { CHAT_SERVICE_LOCK.withLock {
Log.i(TAG, "connection interrupted", disconnectReason) Log.i(TAG, "$name connection interrupted", disconnectReason)
state.onNext(WebSocketConnectionState.DISCONNECTED)
chatService = null chatService = null
} state.onNext(WebSocketConnectionState.DISCONNECTED)
}
}
override fun onQueueEmpty(chat: ChatService) {
val internalPseudoId = nextIncomingMessageInternalPseudoId.getAndIncrement()
val queueEmptyRequest = WebSocketRequestMessage(
verb = SOCKET_EMPTY_REQUEST_VERB,
path = SOCKET_EMPTY_REQUEST_PATH,
body = ByteString.EMPTY,
headers = listOf(),
id = internalPseudoId
)
incomingRequestQueue.put(queueEmptyRequest)
} }
} }
} }

View File

@@ -35,5 +35,5 @@ interface WebSocketConnection {
fun readRequest(timeoutMillis: Long): WebSocketRequestMessage fun readRequest(timeoutMillis: Long): WebSocketRequestMessage
@Throws(IOException::class) @Throws(IOException::class)
fun sendResponse(response: WebSocketResponseMessage?) fun sendResponse(response: WebSocketResponseMessage)
} }

View File

@@ -6,6 +6,9 @@ import io.mockk.mockk
import io.mockk.mockkStatic import io.mockk.mockkStatic
import io.mockk.verify import io.mockk.verify
import io.reactivex.rxjava3.observers.TestObserver import io.reactivex.rxjava3.observers.TestObserver
import okio.ByteString.Companion.toByteString
import org.junit.Assert.assertEquals
import org.junit.Assert.assertTrue
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
import org.signal.libsignal.internal.CompletableFuture import org.signal.libsignal.internal.CompletableFuture
@@ -21,6 +24,7 @@ import java.util.concurrent.CountDownLatch
import java.util.concurrent.ExecutorService import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeoutException
import org.signal.libsignal.net.ChatService.Response as LibSignalResponse import org.signal.libsignal.net.ChatService.Response as LibSignalResponse
import org.signal.libsignal.net.ChatService.ResponseAndDebugInfo as LibSignalDebugResponse import org.signal.libsignal.net.ChatService.ResponseAndDebugInfo as LibSignalDebugResponse
@@ -286,6 +290,39 @@ class LibSignalChatConnectionTest {
} }
} }
@Test
fun connectionInterruptedTest() {
val disconnectReason = ChatServiceException("simulated interrupt")
val connectLatch = CountDownLatch(1)
every { chatService.connect() } answers {
delay {
it.complete(DEBUG_INFO)
connectLatch.countDown()
}
}
connection.connect()
connectLatch.await(100, TimeUnit.MILLISECONDS)
val observer = TestObserver<WebSocketConnectionState>()
connection.state.subscribe(observer)
chatListener!!.onConnectionInterrupted(chatService, disconnectReason)
observer.assertNotComplete()
observer.assertValues(
// We start in the connected state
WebSocketConnectionState.CONNECTED,
// Disconnects as a result of the connection interrupted event
WebSocketConnectionState.DISCONNECTED
)
verify(exactly = 0) {
healthMonitor.onKeepAliveResponse(any(), any())
healthMonitor.onMessageError(any(), any())
}
}
@Test @Test
fun connectionInterrupted() { fun connectionInterrupted() {
val disconnectReason = ChatServiceException("simulated interrupt") val disconnectReason = ChatServiceException("simulated interrupt")
@@ -319,6 +356,80 @@ class LibSignalChatConnectionTest {
} }
} }
@Test
fun incomingRequests() {
val connectLatch = CountDownLatch(1)
val asyncMessageReadLatch = CountDownLatch(1)
every { chatService.connect() } answers {
delay {
it.complete(DEBUG_INFO)
connectLatch.countDown()
}
}
connection.connect()
connectLatch.await(100, TimeUnit.MILLISECONDS)
val observer = TestObserver<WebSocketConnectionState>()
connection.state.subscribe(observer)
var timedOut = false
try {
connection.readRequest(10)
} catch (e: TimeoutException) {
timedOut = true
}
assert(timedOut)
val envelopeA = "msgA".toByteArray()
val envelopeB = "msgB".toByteArray()
val envelopeC = "msgC".toByteArray()
fun assertRequestWithEnvelope(request: WebSocketRequestMessage, envelope: ByteArray) {
assertEquals("PUT", request.verb)
assertEquals("/api/v1/message", request.path)
assertEquals(envelope.toByteString(), request.body!!)
connection.sendResponse(
WebSocketResponseMessage(
request.id,
200,
"OK"
)
)
}
fun assertQueueEmptyRequest(request: WebSocketRequestMessage) {
assertEquals("PUT", request.verb)
assertEquals("/api/v1/queue/empty", request.path)
connection.sendResponse(
WebSocketResponseMessage(
request.id,
200,
"OK"
)
)
}
executor.submit {
assertRequestWithEnvelope(connection.readRequest(10), envelopeA)
asyncMessageReadLatch.countDown()
}
chatListener!!.onIncomingMessage(chatService, envelopeA, 0, null)
asyncMessageReadLatch.await(100, TimeUnit.MILLISECONDS)
chatListener!!.onIncomingMessage(chatService, envelopeB, 0, null)
assertRequestWithEnvelope(connection.readRequestIfAvailable().get(), envelopeB)
chatListener!!.onQueueEmpty(chatService)
assertQueueEmptyRequest(connection.readRequestIfAvailable().get())
chatListener!!.onIncomingMessage(chatService, envelopeC, 0, null)
assertRequestWithEnvelope(connection.readRequestIfAvailable().get(), envelopeC)
assertTrue(connection.readRequestIfAvailable().isEmpty)
}
private fun <T> delay(action: ((CompletableFuture<T>) -> Unit)): CompletableFuture<T> { private fun <T> delay(action: ((CompletableFuture<T>) -> Unit)): CompletableFuture<T> {
val future = CompletableFuture<T>() val future = CompletableFuture<T>()
executor.submit { executor.submit {