Harmonize libsignal-net behavior to match existing websocket implementation.

This commit is contained in:
andrew-signal
2025-02-20 14:11:25 -05:00
committed by Greyson Parrelli
parent be90efa23d
commit c95073e5dd
2 changed files with 161 additions and 34 deletions

View File

@@ -29,6 +29,7 @@ import java.io.IOException
import java.time.Instant
import java.util.Optional
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.Executors
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeoutException
@@ -75,8 +76,16 @@ class LibSignalChatConnection(
private val nextIncomingMessageInternalPseudoId = AtomicLong(1)
val ackSenderForInternalPseudoId = ConcurrentHashMap<Long, ChatConnectionListener.ServerMessageAck>()
// CHAT_SERVICE_LOCK: Protects state, stateChangedOrMessageReceivedCondition, chatConnection, and
// chatConnectionFuture
// stateChangedOrMessageReceivedCondition: derived from CHAT_SERVICE_LOCK, used by readRequest(),
// exists to emulate idiosyncratic behavior of OkHttpWebSocketConnection for readRequest()
// chatConnection: Set only when state == CONNECTED
// chatConnectionFuture: Set only when state == CONNECTING
private val CHAT_SERVICE_LOCK = ReentrantLock()
private val stateChangedOrMessageReceivedCondition = CHAT_SERVICE_LOCK.newCondition()
private var chatConnection: ChatConnection? = null
private var chatConnectionFuture: CompletableFuture<out ChatConnection>? = null
companion object {
const val SERVICE_ENVELOPE_REQUEST_VERB = "PUT"
@@ -126,10 +135,14 @@ class LibSignalChatConnection(
val state = BehaviorSubject.createDefault(WebSocketConnectionState.DISCONNECTED)
val cleanupMonitor = state.subscribe { nextState ->
val stateMonitor = state.subscribe { nextState ->
if (nextState == WebSocketConnectionState.DISCONNECTED) {
cleanup()
}
CHAT_SERVICE_LOCK.withLock {
stateChangedOrMessageReceivedCondition.signalAll()
}
}
private fun cleanup() {
@@ -155,13 +168,15 @@ class LibSignalChatConnection(
return state
}
Log.i(TAG, "$name Connecting...")
val chatConnectionFuture: CompletableFuture<out ChatConnection> = if (credentialsProvider == null) {
chatConnectionFuture = if (credentialsProvider == null) {
network.connectUnauthChat(listener)
} else {
network.connectAuthChat(credentialsProvider.username, credentialsProvider.password, receiveStories, listener)
}
state.onNext(WebSocketConnectionState.CONNECTING)
chatConnectionFuture.whenComplete(
// We are now in the CONNECTING state, so chatConnectionFuture should be set, and there is no
// nullability concern here.
chatConnectionFuture!!.whenComplete(
onSuccess = { connection ->
CHAT_SERVICE_LOCK.withLock {
if (state.value == WebSocketConnectionState.CONNECTING) {
@@ -218,13 +233,14 @@ class LibSignalChatConnection(
return
}
// This avoids a crash when we get a connection lost event during a connection attempt and try
// to cancel a connection that has not yet been fully established.
// TODO [andrew]: Figure out if this is the right long term behavior.
// OkHttpWebSocketConnection will terminate a connection if disconnect() is called while
// the connection itself is still CONNECTING, so we carry forward that behavior here.
if (state.value == WebSocketConnectionState.CONNECTING) {
// The right way to do this is to cancel the CompletableFuture returned by connectChat()
// The right way to do this is to cancel the CompletableFuture returned by connectChat().
// This will terminate forward progress on the connection attempt, and most closely match
// what OkHttpWebSocketConnection does.
// Unfortunately, libsignal's CompletableFuture does not yet support cancellation.
// Instead, we set a flag to disconnect() as soon as the connection completes.
// So, instead, we set a flag to disconnect() as soon as the connection completes.
// TODO [andrew]: Add cancellation support to CompletableFuture and use it here
state.onNext(WebSocketConnectionState.DISCONNECTING)
return
@@ -289,11 +305,33 @@ class LibSignalChatConnection(
override fun sendKeepAlive() {
CHAT_SERVICE_LOCK.withLock {
// This is a stronger check than isDead, to handle the case where chatConnection may be null
// because we are still connecting.
// TODO [andrew]: Decide if this is the right behavior long term, or if we want to queue these
// like we plan to queue other requests long term.
if (state.value != WebSocketConnectionState.CONNECTED) {
if (isDead()) {
// This matches the behavior of OkHttpWebSocketConnection, where if a keep alive is sent
// while we are not connected, we simply drop the keep alive.
return
}
if (state.value == WebSocketConnectionState.CONNECTING) {
// Handle the special case where we are connecting, so we cannot (yet) send the keep-alive.
// OkHttpWebSocketConnection buffers the keep alive request, and sends it when the connection
// completes.
// We just checked that we are in the CONNECTING state, and we hold the CHAT_SERVICE_LOCK, so
// our state cannot change, thus there is no nullability concern with chatConnectionFuture.
Log.i(TAG, "$name Buffering keep alive to send after connection establishment")
chatConnectionFuture!!.whenComplete(
onSuccess = {
Log.i(TAG, "$name Sending buffered keep alive")
// sendKeepAlive() will internally grab the CHAT_SERVICE_LOCK and check to ensure we are
// still in the CONNECTED state when this callback runs, so we do not need to worry about
// any state here.
sendKeepAlive()
},
onFailure = {
// OkHttpWebSocketConnection did not report a keep alive failure to the healthMonitor
// when a buffered keep alive failed to send because the underlying connection
// establishment failed, so neither do we.
}
)
return
}
@@ -332,22 +370,76 @@ class LibSignalChatConnection(
return Optional.ofNullable(incomingMessage)
}
/**
* Blocks until a request is received from the underlying ChatConnection.
*
* This method's behavior is critical for message retrieval and must adhere to the following:
*
* - Blocks until a request is available.
* - If no message is received within the specified [timeoutMillis], a [TimeoutException] is thrown.
* - If the ChatConnection becomes disconnected while waiting, an [IOException] is thrown immediately.
* - If invoked when the ChatConnection is dead (i.e. disconnected or failed), an [IOException] is thrown.
* - If the ChatConnection is still in the process of connecting, the method will block until the connection
* is established and a message is received. The time spent waiting for the connection is counted towards
* the [timeoutMillis]. Should the connection attempt eventually fail, an [IOException] is thrown promptly.
*
* **Note:** This method is used by the MessageRetrievalThread to receive updates about the connection state
* from other threads. Any delay in throwing exceptions could block this thread, resulting in prolonged holding
* of the Foreground Service and wake lock, which may lead to adverse behavior by the operating system.
*
* @param timeoutMillis the maximum time in milliseconds to wait for a request.
* @return the received [WebSocketRequestMessage].
* @throws TimeoutException if the timeout elapses without receiving a message.
* @throws IOException if the ChatConnection becomes disconnected, is dead, or if the connection attempt fails.
*/
override fun readRequest(timeoutMillis: Long): WebSocketRequestMessage {
return readRequestInternal(timeoutMillis, timeoutMillis)
}
private fun readRequestInternal(timeoutMillis: Long, originalTimeoutMillis: Long): WebSocketRequestMessage {
if (timeoutMillis < 0) {
throw TimeoutException("No message available after $originalTimeoutMillis ms")
if (timeoutMillis <= 0) {
// OkHttpWebSocketConnection throws a TimeoutException in this case, so we do too.
throw TimeoutException("Invalid timeoutMillis")
}
val startTime = System.currentTimeMillis()
try {
return incomingRequestQueue.poll(timeoutMillis, TimeUnit.MILLISECONDS) ?: throw TimeoutException("No message available after $originalTimeoutMillis ms")
} catch (e: InterruptedException) {
val elapsedTimeMillis = System.currentTimeMillis() - startTime
val timeoutRemainingMillis = timeoutMillis - elapsedTimeMillis
return readRequestInternal(timeoutRemainingMillis, originalTimeoutMillis)
CHAT_SERVICE_LOCK.withLock {
if (isDead()) {
// Matches behavior of OkHttpWebSocketConnection
throw IOException("Connection closed!")
}
var remainingTimeoutMillis = timeoutMillis
fun couldGetRequest(): Boolean {
return state.value == WebSocketConnectionState.CONNECTED || state.value == WebSocketConnectionState.CONNECTING
}
while (couldGetRequest() && incomingRequestQueue.isEmpty()) {
if (remainingTimeoutMillis <= 0) {
throw TimeoutException("Timeout exceeded after $timeoutMillis ms")
}
try {
// This condition variable is created from CHAT_SERVICE_LOCK, and thus releases CHAT_SERVICE_LOCK
// while we await the condition variable.
stateChangedOrMessageReceivedCondition.await(remainingTimeoutMillis, TimeUnit.MILLISECONDS)
} catch (_: InterruptedException) { }
val elapsedTimeMillis = System.currentTimeMillis() - startTime
remainingTimeoutMillis = timeoutMillis - elapsedTimeMillis
}
if (!incomingRequestQueue.isEmpty()) {
return incomingRequestQueue.poll()
} else if (!couldGetRequest()) {
throw IOException("Connection closed!")
} else {
// This happens if we somehow break out of the loop but incomingRequestQueue is empty
// and we were still in a state where we could get a request.
// This *could* theoretically happen if two different threads call readRequest at the same time, and
// this thread is the one that loses the race to take the request off the queue.
// (NB: I don't think this is a practical issue, because readRequest() should only be called from
// the MessageRetrievalThread, but OkHttpWebSocketConnection treated this as a TimeoutException, so
// this class also dutifully treats it as a TimeoutException.)
throw TimeoutException("Incoming request queue was empty!")
}
}
}
@@ -366,6 +458,8 @@ class LibSignalChatConnection(
private val listener = LibSignalChatListener()
private inner class LibSignalChatListener : ChatConnectionListener {
private val executor = Executors.newSingleThreadExecutor()
override fun onIncomingMessage(chat: ChatConnection, envelope: ByteArray, serverDeliveryTimestamp: Long, sendAck: ChatConnectionListener.ServerMessageAck?) {
// NB: The order here is intentional to ensure concurrency-safety, so that when a request is pulled off the queue, its sendAck is
// already in the ackSender map, if it exists.
@@ -381,6 +475,12 @@ class LibSignalChatConnection(
ackSenderForInternalPseudoId[internalPseudoId] = sendAck
}
incomingRequestQueue.put(incomingWebSocketRequest)
// Try to not block the ChatConnectionListener callback context if we can help it.
executor.submit {
CHAT_SERVICE_LOCK.withLock {
stateChangedOrMessageReceivedCondition.signalAll()
}
}
}
override fun onConnectionInterrupted(chat: ChatConnection, disconnectReason: ChatServiceException?) {
@@ -402,6 +502,12 @@ class LibSignalChatConnection(
id = internalPseudoId
)
incomingRequestQueue.put(queueEmptyRequest)
// Try to not block the ChatConnectionListener callback context if we can help it.
executor.submit {
CHAT_SERVICE_LOCK.withLock {
stateChangedOrMessageReceivedCondition.signalAll()
}
}
}
}
}

View File

@@ -8,6 +8,7 @@ import io.mockk.verify
import io.reactivex.rxjava3.observers.TestObserver
import okio.ByteString.Companion.toByteString
import org.junit.Assert.assertEquals
import org.junit.Assert.assertThrows
import org.junit.Assert.assertTrue
import org.junit.Before
import org.junit.Test
@@ -19,6 +20,7 @@ import org.signal.libsignal.net.Network
import org.signal.libsignal.net.UnauthenticatedChatConnection
import org.whispersystems.signalservice.api.websocket.HealthMonitor
import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState
import java.io.IOException
import java.util.concurrent.CountDownLatch
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
@@ -307,6 +309,34 @@ class LibSignalChatConnectionTest {
}
}
// Regression test: readRequest() must throw once the underlying connection disconnects.
// If it keeps blocking instead, the app gets stuck in a "fetching new messages" state.
@Test
fun regressionTestReadRequestThrowsOnDisconnect() {
  setupConnectedConnection()

  // The disconnect is scheduled well inside the read timeout, so the IOException has to
  // come from the disconnect; a read that merely timed out would throw TimeoutException.
  val disconnectDelayMillis = 100L
  val readTimeoutMillis = 1000L

  executor.submit {
    Thread.sleep(disconnectDelayMillis)
    chatConnection.disconnect()
  }

  assertThrows(IOException::class.java) {
    connection.readRequest(readTimeoutMillis)
  }
}
// Verifies that readRequest() gives up with a TimeoutException when no message arrives
// within the requested wait, rather than blocking indefinitely.
//
// The JUnit timeout is only a guard against a hang. It is kept far above the 10 ms read
// timeout because this method body also performs connection setup; a budget only a few
// milliseconds above the wait makes the test fail spuriously on slow or loaded machines.
@Test(timeout = 1000)
fun readRequestDoesTimeOut() {
  setupConnectedConnection()

  assertThrows(TimeoutException::class.java) {
    connection.readRequest(10)
  }
}
// Test reading incoming requests from the queue.
// We'll simulate onIncomingMessage() from the ChatConnectionListener, then read them from the LibSignalChatConnection.
@Test
@@ -316,15 +346,6 @@ class LibSignalChatConnectionTest {
val observer = TestObserver<WebSocketConnectionState>()
connection.state.subscribe(observer)
// Confirm that readRequest times out if there's no message.
var timedOut = false
try {
connection.readRequest(10)
} catch (e: TimeoutException) {
timedOut = true
}
assertTrue(timedOut)
// We'll now simulate incoming messages
val envelopeA = "msgA".toByteArray()
val envelopeB = "msgB".toByteArray()