mirror of
https://github.com/signalapp/Signal-Android.git
synced 2026-05-08 17:29:02 +01:00
Harmonize libsignal-net behavior to match existing websocket implementation.
This commit is contained in:
committed by
Greyson Parrelli
parent
be90efa23d
commit
c95073e5dd
+131
-25
@@ -29,6 +29,7 @@ import java.io.IOException
|
|||||||
import java.time.Instant
|
import java.time.Instant
|
||||||
import java.util.Optional
|
import java.util.Optional
|
||||||
import java.util.concurrent.ConcurrentHashMap
|
import java.util.concurrent.ConcurrentHashMap
|
||||||
|
import java.util.concurrent.Executors
|
||||||
import java.util.concurrent.LinkedBlockingQueue
|
import java.util.concurrent.LinkedBlockingQueue
|
||||||
import java.util.concurrent.TimeUnit
|
import java.util.concurrent.TimeUnit
|
||||||
import java.util.concurrent.TimeoutException
|
import java.util.concurrent.TimeoutException
|
||||||
@@ -75,8 +76,16 @@ class LibSignalChatConnection(
|
|||||||
private val nextIncomingMessageInternalPseudoId = AtomicLong(1)
|
private val nextIncomingMessageInternalPseudoId = AtomicLong(1)
|
||||||
val ackSenderForInternalPseudoId = ConcurrentHashMap<Long, ChatConnectionListener.ServerMessageAck>()
|
val ackSenderForInternalPseudoId = ConcurrentHashMap<Long, ChatConnectionListener.ServerMessageAck>()
|
||||||
|
|
||||||
|
// CHAT_SERVICE_LOCK: Protects state, stateChangedOrMessageReceivedCondition, chatConnection, and
|
||||||
|
// chatConnectionFuture
|
||||||
|
// stateChangedOrMessageReceivedCondition: derived from CHAT_SERVICE_LOCK, used by readRequest(),
|
||||||
|
// exists to emulate idiosyncratic behavior of OkHttpWebSocketConnection for readRequest()
|
||||||
|
// chatConnection: Set only when state == CONNECTED
|
||||||
|
// chatConnectionFuture: Set only when state == CONNECTING
|
||||||
private val CHAT_SERVICE_LOCK = ReentrantLock()
|
private val CHAT_SERVICE_LOCK = ReentrantLock()
|
||||||
|
private val stateChangedOrMessageReceivedCondition = CHAT_SERVICE_LOCK.newCondition()
|
||||||
private var chatConnection: ChatConnection? = null
|
private var chatConnection: ChatConnection? = null
|
||||||
|
private var chatConnectionFuture: CompletableFuture<out ChatConnection>? = null
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
const val SERVICE_ENVELOPE_REQUEST_VERB = "PUT"
|
const val SERVICE_ENVELOPE_REQUEST_VERB = "PUT"
|
||||||
@@ -126,10 +135,14 @@ class LibSignalChatConnection(
|
|||||||
|
|
||||||
val state = BehaviorSubject.createDefault(WebSocketConnectionState.DISCONNECTED)
|
val state = BehaviorSubject.createDefault(WebSocketConnectionState.DISCONNECTED)
|
||||||
|
|
||||||
val cleanupMonitor = state.subscribe { nextState ->
|
val stateMonitor = state.subscribe { nextState ->
|
||||||
if (nextState == WebSocketConnectionState.DISCONNECTED) {
|
if (nextState == WebSocketConnectionState.DISCONNECTED) {
|
||||||
cleanup()
|
cleanup()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CHAT_SERVICE_LOCK.withLock {
|
||||||
|
stateChangedOrMessageReceivedCondition.signalAll()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun cleanup() {
|
private fun cleanup() {
|
||||||
@@ -155,13 +168,15 @@ class LibSignalChatConnection(
|
|||||||
return state
|
return state
|
||||||
}
|
}
|
||||||
Log.i(TAG, "$name Connecting...")
|
Log.i(TAG, "$name Connecting...")
|
||||||
val chatConnectionFuture: CompletableFuture<out ChatConnection> = if (credentialsProvider == null) {
|
chatConnectionFuture = if (credentialsProvider == null) {
|
||||||
network.connectUnauthChat(listener)
|
network.connectUnauthChat(listener)
|
||||||
} else {
|
} else {
|
||||||
network.connectAuthChat(credentialsProvider.username, credentialsProvider.password, receiveStories, listener)
|
network.connectAuthChat(credentialsProvider.username, credentialsProvider.password, receiveStories, listener)
|
||||||
}
|
}
|
||||||
state.onNext(WebSocketConnectionState.CONNECTING)
|
state.onNext(WebSocketConnectionState.CONNECTING)
|
||||||
chatConnectionFuture.whenComplete(
|
// We are now in the CONNECTING state, so chatConnectionFuture should be set, and there is no
|
||||||
|
// nullability concern here.
|
||||||
|
chatConnectionFuture!!.whenComplete(
|
||||||
onSuccess = { connection ->
|
onSuccess = { connection ->
|
||||||
CHAT_SERVICE_LOCK.withLock {
|
CHAT_SERVICE_LOCK.withLock {
|
||||||
if (state.value == WebSocketConnectionState.CONNECTING) {
|
if (state.value == WebSocketConnectionState.CONNECTING) {
|
||||||
@@ -218,13 +233,14 @@ class LibSignalChatConnection(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// This avoids a crash when we get a connection lost event during a connection attempt and try
|
// OkHttpWebSocketConnection will terminate a connection if disconnect() is called while
|
||||||
// to cancel a connection that has not yet been fully established.
|
// the connection itself is still CONNECTING, so we carry forward that behavior here.
|
||||||
// TODO [andrew]: Figure out if this is the right long term behavior.
|
|
||||||
if (state.value == WebSocketConnectionState.CONNECTING) {
|
if (state.value == WebSocketConnectionState.CONNECTING) {
|
||||||
// The right way to do this is to cancel the CompletableFuture returned by connectChat()
|
// The right way to do this is to cancel the CompletableFuture returned by connectChat().
|
||||||
|
// This will terminate forward progress on the connection attempt, and mostly closely match
|
||||||
|
// what OkHttpWebSocketConnection does.
|
||||||
// Unfortunately, libsignal's CompletableFuture does not yet support cancellation.
|
// Unfortunately, libsignal's CompletableFuture does not yet support cancellation.
|
||||||
// Instead, we set a flag to disconnect() as soon as the connection completes.
|
// So, instead, we set a flag to disconnect() as soon as the connection completes.
|
||||||
// TODO [andrew]: Add cancellation support to CompletableFuture and use it here
|
// TODO [andrew]: Add cancellation support to CompletableFuture and use it here
|
||||||
state.onNext(WebSocketConnectionState.DISCONNECTING)
|
state.onNext(WebSocketConnectionState.DISCONNECTING)
|
||||||
return
|
return
|
||||||
@@ -289,11 +305,33 @@ class LibSignalChatConnection(
|
|||||||
|
|
||||||
override fun sendKeepAlive() {
|
override fun sendKeepAlive() {
|
||||||
CHAT_SERVICE_LOCK.withLock {
|
CHAT_SERVICE_LOCK.withLock {
|
||||||
// This is a stronger check than isDead, to handle the case where chatConnection may be null
|
if (isDead()) {
|
||||||
// because we are still connecting.
|
// This matches the behavior of OkHttpWebSocketConnection, where if a keep alive is sent
|
||||||
// TODO [andrew]: Decide if this is the right behavior long term, or if we want to queue these
|
// while we are not connected, we simply drop the keep alive.
|
||||||
// like we plan to queue other requests long term.
|
return
|
||||||
if (state.value != WebSocketConnectionState.CONNECTED) {
|
}
|
||||||
|
|
||||||
|
if (state.value == WebSocketConnectionState.CONNECTING) {
|
||||||
|
// Handle the special case where we are connecting, so we cannot (yet) send the keep-alive.
|
||||||
|
// OkHttpWebSocketConnection buffers the keep alive request, and sends it when the connection
|
||||||
|
// completes.
|
||||||
|
// We just checked that we are in the CONNECTING state, and we hold the CHAT_SERVICE_LOCK, so
|
||||||
|
// our state cannot change, thus there is no nullability concern with chatConnectionFuture.
|
||||||
|
Log.i(TAG, "$name Buffering keep alive to send after connection establishment")
|
||||||
|
chatConnectionFuture!!.whenComplete(
|
||||||
|
onSuccess = {
|
||||||
|
Log.i(TAG, "$name Sending buffered keep alive")
|
||||||
|
// sendKeepAlive() will internally grab the CHAT_SERVICE_LOCK and check to ensure we are
|
||||||
|
// still in the CONNECTED state when this callback runs, so we do not need to worry about
|
||||||
|
// any state here.
|
||||||
|
sendKeepAlive()
|
||||||
|
},
|
||||||
|
onFailure = {
|
||||||
|
// OkHttpWebSocketConnection did not report a keep alive failure to the healthMonitor
|
||||||
|
// when a buffered keep alive failed to send because the underlying connection
|
||||||
|
// establishment failed, so neither do we.
|
||||||
|
}
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -332,22 +370,76 @@ class LibSignalChatConnection(
|
|||||||
return Optional.ofNullable(incomingMessage)
|
return Optional.ofNullable(incomingMessage)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Blocks until a request is received from the underlying ChatConnection.
|
||||||
|
*
|
||||||
|
* This method’s behavior is critical for message retrieval and must adhere to the following:
|
||||||
|
*
|
||||||
|
* - Blocks until a request is available.
|
||||||
|
* - If no message is received within the specified [timeoutMillis], a [TimeoutException] is thrown.
|
||||||
|
* - If the ChatConnection becomes disconnected while waiting, an [IOException] is thrown immediately.
|
||||||
|
* - If invoked when the ChatConnection is dead (i.e. disconnected or failed), an [IOException] is thrown.
|
||||||
|
* - If the ChatConnection is still in the process of connecting, the method will block until the connection
|
||||||
|
* is established and a message is received. The time spent waiting for the connection is counted towards
|
||||||
|
* the [timeoutMillis]. Should the connection attempt eventually fail, an [IOException] is thrown promptly.
|
||||||
|
*
|
||||||
|
* **Note:** This method is used by the MessageRetrievalThread to receive updates about the connection state
|
||||||
|
* from other threads. Any delay in throwing exceptions could block this thread, resulting in prolonged holding
|
||||||
|
* of the Foreground Service and wake lock, which may lead to adverse behavior by the operating system.
|
||||||
|
*
|
||||||
|
* @param timeoutMillis the maximum time in milliseconds to wait for a request.
|
||||||
|
* @return the received [WebSocketRequestMessage].
|
||||||
|
* @throws TimeoutException if the timeout elapses without receiving a message.
|
||||||
|
* @throws IOException if the ChatConnection becomes disconnected, is dead, or if the connection attempt fails.
|
||||||
|
*/
|
||||||
override fun readRequest(timeoutMillis: Long): WebSocketRequestMessage {
|
override fun readRequest(timeoutMillis: Long): WebSocketRequestMessage {
|
||||||
return readRequestInternal(timeoutMillis, timeoutMillis)
|
if (timeoutMillis <= 0) {
|
||||||
}
|
// OkHttpWebSocketConnection throws a TimeoutException in this case, so we do too.
|
||||||
|
throw TimeoutException("Invalid timeoutMillis")
|
||||||
private fun readRequestInternal(timeoutMillis: Long, originalTimeoutMillis: Long): WebSocketRequestMessage {
|
|
||||||
if (timeoutMillis < 0) {
|
|
||||||
throw TimeoutException("No message available after $originalTimeoutMillis ms")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
val startTime = System.currentTimeMillis()
|
val startTime = System.currentTimeMillis()
|
||||||
try {
|
|
||||||
return incomingRequestQueue.poll(timeoutMillis, TimeUnit.MILLISECONDS) ?: throw TimeoutException("No message available after $originalTimeoutMillis ms")
|
CHAT_SERVICE_LOCK.withLock {
|
||||||
} catch (e: InterruptedException) {
|
if (isDead()) {
|
||||||
val elapsedTimeMillis = System.currentTimeMillis() - startTime
|
// Matches behavior of OkHttpWebSocketConnection
|
||||||
val timeoutRemainingMillis = timeoutMillis - elapsedTimeMillis
|
throw IOException("Connection closed!")
|
||||||
return readRequestInternal(timeoutRemainingMillis, originalTimeoutMillis)
|
}
|
||||||
|
|
||||||
|
var remainingTimeoutMillis = timeoutMillis
|
||||||
|
|
||||||
|
fun couldGetRequest(): Boolean {
|
||||||
|
return state.value == WebSocketConnectionState.CONNECTED || state.value == WebSocketConnectionState.CONNECTING
|
||||||
|
}
|
||||||
|
|
||||||
|
while (couldGetRequest() && incomingRequestQueue.isEmpty()) {
|
||||||
|
if (remainingTimeoutMillis <= 0) {
|
||||||
|
throw TimeoutException("Timeout exceeded after $timeoutMillis ms")
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
// This condition variable is created from CHAT_SERVICE_LOCK, and thus releases CHAT_SERVICE_LOCK
|
||||||
|
// while we await the condition variable.
|
||||||
|
stateChangedOrMessageReceivedCondition.await(remainingTimeoutMillis, TimeUnit.MILLISECONDS)
|
||||||
|
} catch (_: InterruptedException) { }
|
||||||
|
val elapsedTimeMillis = System.currentTimeMillis() - startTime
|
||||||
|
remainingTimeoutMillis = timeoutMillis - elapsedTimeMillis
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!incomingRequestQueue.isEmpty()) {
|
||||||
|
return incomingRequestQueue.poll()
|
||||||
|
} else if (!couldGetRequest()) {
|
||||||
|
throw IOException("Connection closed!")
|
||||||
|
} else {
|
||||||
|
// This happens if we somehow break out of the loop but incomingRequestQueue is empty
|
||||||
|
// and we were still in a state where we could get a request.
|
||||||
|
// This *could* theoretically happen if two different threads call readRequest at the same time,
|
||||||
|
// this thread is the one that loses the race to take the request off the queue.
|
||||||
|
// (NB: I don't think this is a practical issue, because readRequest() should only be called from
|
||||||
|
// the MessageRetrievalThread, but OkHttpWebSocketConnection treated this as a TimeoutException, so
|
||||||
|
// this class also dutifully treats it as a TimeoutException.)
|
||||||
|
throw TimeoutException("Incoming request queue was empty!")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -366,6 +458,8 @@ class LibSignalChatConnection(
|
|||||||
private val listener = LibSignalChatListener()
|
private val listener = LibSignalChatListener()
|
||||||
|
|
||||||
private inner class LibSignalChatListener : ChatConnectionListener {
|
private inner class LibSignalChatListener : ChatConnectionListener {
|
||||||
|
private val executor = Executors.newSingleThreadExecutor()
|
||||||
|
|
||||||
override fun onIncomingMessage(chat: ChatConnection, envelope: ByteArray, serverDeliveryTimestamp: Long, sendAck: ChatConnectionListener.ServerMessageAck?) {
|
override fun onIncomingMessage(chat: ChatConnection, envelope: ByteArray, serverDeliveryTimestamp: Long, sendAck: ChatConnectionListener.ServerMessageAck?) {
|
||||||
// NB: The order here is intentional to ensure concurrency-safety, so that when a request is pulled off the queue, its sendAck is
|
// NB: The order here is intentional to ensure concurrency-safety, so that when a request is pulled off the queue, its sendAck is
|
||||||
// already in the ackSender map, if it exists.
|
// already in the ackSender map, if it exists.
|
||||||
@@ -381,6 +475,12 @@ class LibSignalChatConnection(
|
|||||||
ackSenderForInternalPseudoId[internalPseudoId] = sendAck
|
ackSenderForInternalPseudoId[internalPseudoId] = sendAck
|
||||||
}
|
}
|
||||||
incomingRequestQueue.put(incomingWebSocketRequest)
|
incomingRequestQueue.put(incomingWebSocketRequest)
|
||||||
|
// Try to not block the ChatConnectionListener callback context if we can help it.
|
||||||
|
executor.submit {
|
||||||
|
CHAT_SERVICE_LOCK.withLock {
|
||||||
|
stateChangedOrMessageReceivedCondition.signalAll()
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun onConnectionInterrupted(chat: ChatConnection, disconnectReason: ChatServiceException?) {
|
override fun onConnectionInterrupted(chat: ChatConnection, disconnectReason: ChatServiceException?) {
|
||||||
@@ -402,6 +502,12 @@ class LibSignalChatConnection(
|
|||||||
id = internalPseudoId
|
id = internalPseudoId
|
||||||
)
|
)
|
||||||
incomingRequestQueue.put(queueEmptyRequest)
|
incomingRequestQueue.put(queueEmptyRequest)
|
||||||
|
// Try to not block the ChatConnectionListener callback context if we can help it.
|
||||||
|
executor.submit {
|
||||||
|
CHAT_SERVICE_LOCK.withLock {
|
||||||
|
stateChangedOrMessageReceivedCondition.signalAll()
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+30
-9
@@ -8,6 +8,7 @@ import io.mockk.verify
|
|||||||
import io.reactivex.rxjava3.observers.TestObserver
|
import io.reactivex.rxjava3.observers.TestObserver
|
||||||
import okio.ByteString.Companion.toByteString
|
import okio.ByteString.Companion.toByteString
|
||||||
import org.junit.Assert.assertEquals
|
import org.junit.Assert.assertEquals
|
||||||
|
import org.junit.Assert.assertThrows
|
||||||
import org.junit.Assert.assertTrue
|
import org.junit.Assert.assertTrue
|
||||||
import org.junit.Before
|
import org.junit.Before
|
||||||
import org.junit.Test
|
import org.junit.Test
|
||||||
@@ -19,6 +20,7 @@ import org.signal.libsignal.net.Network
|
|||||||
import org.signal.libsignal.net.UnauthenticatedChatConnection
|
import org.signal.libsignal.net.UnauthenticatedChatConnection
|
||||||
import org.whispersystems.signalservice.api.websocket.HealthMonitor
|
import org.whispersystems.signalservice.api.websocket.HealthMonitor
|
||||||
import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState
|
import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState
|
||||||
|
import java.io.IOException
|
||||||
import java.util.concurrent.CountDownLatch
|
import java.util.concurrent.CountDownLatch
|
||||||
import java.util.concurrent.ExecutorService
|
import java.util.concurrent.ExecutorService
|
||||||
import java.util.concurrent.Executors
|
import java.util.concurrent.Executors
|
||||||
@@ -307,6 +309,34 @@ class LibSignalChatConnectionTest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If readRequest() does not throw when the underlying connection disconnects, this
|
||||||
|
// causes the app to get stuck in a "fetching new messages" state.
|
||||||
|
@Test
|
||||||
|
fun regressionTestReadRequestThrowsOnDisconnect() {
|
||||||
|
setupConnectedConnection()
|
||||||
|
|
||||||
|
executor.submit {
|
||||||
|
Thread.sleep(100)
|
||||||
|
chatConnection.disconnect()
|
||||||
|
}
|
||||||
|
|
||||||
|
assertThrows(IOException::class.java) {
|
||||||
|
connection.readRequest(1000)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(timeout = 20)
|
||||||
|
fun readRequestDoesTimeOut() {
|
||||||
|
setupConnectedConnection()
|
||||||
|
|
||||||
|
val observer = TestObserver<WebSocketConnectionState>()
|
||||||
|
connection.state.subscribe(observer)
|
||||||
|
|
||||||
|
assertThrows(TimeoutException::class.java) {
|
||||||
|
connection.readRequest(10)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Test reading incoming requests from the queue.
|
// Test reading incoming requests from the queue.
|
||||||
// We'll simulate onIncomingMessage() from the ChatConnectionListener, then read them from the LibSignalChatConnection.
|
// We'll simulate onIncomingMessage() from the ChatConnectionListener, then read them from the LibSignalChatConnection.
|
||||||
@Test
|
@Test
|
||||||
@@ -316,15 +346,6 @@ class LibSignalChatConnectionTest {
|
|||||||
val observer = TestObserver<WebSocketConnectionState>()
|
val observer = TestObserver<WebSocketConnectionState>()
|
||||||
connection.state.subscribe(observer)
|
connection.state.subscribe(observer)
|
||||||
|
|
||||||
// Confirm that readRequest times out if there's no message.
|
|
||||||
var timedOut = false
|
|
||||||
try {
|
|
||||||
connection.readRequest(10)
|
|
||||||
} catch (e: TimeoutException) {
|
|
||||||
timedOut = true
|
|
||||||
}
|
|
||||||
assertTrue(timedOut)
|
|
||||||
|
|
||||||
// We'll now simulate incoming messages
|
// We'll now simulate incoming messages
|
||||||
val envelopeA = "msgA".toByteArray()
|
val envelopeA = "msgA".toByteArray()
|
||||||
val envelopeB = "msgB".toByteArray()
|
val envelopeB = "msgB".toByteArray()
|
||||||
|
|||||||
Reference in New Issue
Block a user