Update LibSignalChatConnection to use new ChatConnection API rather than ChatService

This commit is contained in:
andrew-signal
2025-02-04 14:34:07 -05:00
committed by Greyson Parrelli
parent fe44789d88
commit 2186e2bf92
5 changed files with 218 additions and 253 deletions

View File

@@ -438,7 +438,7 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider {
BuildConfig.SIGNAL_AGENT,
healthMonitor,
Stories.isFeatureEnabled(),
LibSignalNetworkExtensions.createChatService(libSignalNetworkSupplier.get(), null, Stories.isFeatureEnabled(), null),
libSignalNetworkSupplier.get(),
shadowPercentage,
bridge
);

View File

@@ -12,14 +12,16 @@ import io.reactivex.rxjava3.subjects.BehaviorSubject
import io.reactivex.rxjava3.subjects.SingleSubject
import okio.ByteString
import okio.ByteString.Companion.toByteString
import okio.withLock
import org.signal.core.util.logging.Log
import org.signal.libsignal.net.AuthenticatedChatService
import org.signal.libsignal.net.ChatListener
import org.signal.libsignal.net.ChatService
import org.signal.libsignal.internal.CompletableFuture
import org.signal.libsignal.net.AuthenticatedChatConnection
import org.signal.libsignal.net.ChatConnection
import org.signal.libsignal.net.ChatConnectionListener
import org.signal.libsignal.net.ChatServiceException
import org.signal.libsignal.net.DeviceDeregisteredException
import org.signal.libsignal.net.Network
import org.signal.libsignal.net.UnauthenticatedChatService
import org.signal.libsignal.net.UnauthenticatedChatConnection
import org.whispersystems.signalservice.api.util.CredentialsProvider
import org.whispersystems.signalservice.api.websocket.HealthMonitor
import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState
@@ -35,21 +37,20 @@ import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock
import kotlin.time.Duration.Companion.seconds
import org.signal.libsignal.net.ChatService.Request as LibSignalRequest
import org.signal.libsignal.net.ChatService.Response as LibSignalResponse
import org.signal.libsignal.net.ChatConnection.Request as LibSignalRequest
import org.signal.libsignal.net.ChatConnection.Response as LibSignalResponse
/**
* Implements the WebSocketConnection interface via libsignal-net
*
* Notable implementation choices:
* - [chatService] contains both the authenticated and unauthenticated connections,
* which one to use for [sendRequest]/[sendResponse] is based on [isAuthenticated].
* - keep-alive requests always use the [org.signal.libsignal.net.ChatService.unauthenticatedSendAndDebug]
* API, and log the debug info on success.
* - regular sends use [org.signal.libsignal.net.ChatService.unauthenticatedSend] and don't create any overhead.
* - [chatConnection] contains either an authenticated or an unauthenticated connections
* - keep-alive requests are sent on both authenticated and unauthenticated connections, mirroring the existing OkHttp behavior
* - [org.whispersystems.signalservice.api.websocket.WebSocketConnectionState] reporting is implemented
* as close as possible to the original implementation in
* [org.whispersystems.signalservice.internal.websocket.OkHttpWebSocketConnection].
* - we expose fake "psuedo IDs" for incoming requests so the layer on top of ours can work with IDs, just
* like with the old OkHttp implementation, and internally we map these IDs to AckSenders
*/
class LibSignalChatConnection(
name: String,
@@ -73,10 +74,10 @@ class LibSignalChatConnection(
// tell us to send a response for that ID, and then we use the pseudo ID as a handle to find
// the callback given to us earlier by libsignal-net, and we call that callback.
private val nextIncomingMessageInternalPseudoId = AtomicLong(1)
val ackSenderForInternalPseudoId = ConcurrentHashMap<Long, ChatListener.ServerMessageAck>()
val ackSenderForInternalPseudoId = ConcurrentHashMap<Long, ChatConnectionListener.ServerMessageAck>()
private val CHAT_SERVICE_LOCK = ReentrantLock()
private var chatService: ChatService? = null
private var chatConnection: ChatConnection? = null
companion object {
const val SERVICE_ENVELOPE_REQUEST_VERB = "PUT"
@@ -142,23 +143,38 @@ class LibSignalChatConnection(
// There's no sense in resetting nextIncomingMessageInternalPseudoId.
}
init {
if (credentialsProvider != null) {
check(!credentialsProvider.username.isNullOrEmpty())
check(!credentialsProvider.password.isNullOrEmpty())
}
}
override fun connect(): Observable<WebSocketConnectionState> {
CHAT_SERVICE_LOCK.withLock {
if (chatService != null) {
if (!isDead()) {
return state
}
Log.i(TAG, "$name Connecting...")
chatService = network.createChatService(credentialsProvider, receiveStories, listener).apply {
state.onNext(WebSocketConnectionState.CONNECTING)
connect().whenComplete(
onSuccess = { debugInfo ->
val chatConnectionFuture: CompletableFuture<out ChatConnection> = if (credentialsProvider == null) {
network.connectUnauthChat(listener)
} else {
network.connectAuthChat(credentialsProvider.username, credentialsProvider.password, receiveStories, listener)
}
state.onNext(WebSocketConnectionState.CONNECTING)
chatConnectionFuture.whenComplete(
onSuccess = { connection ->
CHAT_SERVICE_LOCK.withLock {
chatConnection = connection
connection?.start()
Log.i(TAG, "$name Connected")
Log.d(TAG, "$name $debugInfo")
state.onNext(WebSocketConnectionState.CONNECTED)
},
onFailure = { throwable ->
}
},
onFailure = { throwable ->
CHAT_SERVICE_LOCK.withLock {
Log.w(TAG, "$name [connect] Failure:", throwable)
chatConnection = null
// Internally, libsignal-net will throw this DeviceDeregisteredException when the HTTP CONNECT
// request returns HTTP 403.
// The chat service currently does not return HTTP 401 on /v1/websocket.
@@ -169,27 +185,38 @@ class LibSignalChatConnection(
state.onNext(WebSocketConnectionState.FAILED)
}
}
)
}
}
)
return state
}
}
override fun isDead(): Boolean {
CHAT_SERVICE_LOCK.withLock {
return chatService == null
return when (state.value) {
WebSocketConnectionState.DISCONNECTED,
WebSocketConnectionState.DISCONNECTING,
WebSocketConnectionState.FAILED,
WebSocketConnectionState.AUTHENTICATION_FAILED -> true
WebSocketConnectionState.CONNECTING,
WebSocketConnectionState.CONNECTED,
WebSocketConnectionState.RECONNECTING -> false
null -> throw IllegalStateException("LibSignalChatConnection.state can never be null")
}
}
}
override fun disconnect() {
CHAT_SERVICE_LOCK.withLock {
if (chatService == null) {
if (isDead()) {
return
}
Log.i(TAG, "$name Disconnecting...")
state.onNext(WebSocketConnectionState.DISCONNECTING)
chatService!!.disconnect()
chatConnection!!.disconnect()
.whenComplete(
onSuccess = {
Log.i(TAG, "$name Disconnected")
@@ -200,18 +227,18 @@ class LibSignalChatConnection(
state.onNext(WebSocketConnectionState.DISCONNECTED)
}
)
chatService = null
chatConnection = null
}
}
override fun sendRequest(request: WebSocketRequestMessage): Single<WebsocketResponse> {
CHAT_SERVICE_LOCK.withLock {
if (chatService == null) {
if (isDead()) {
return Single.error(IOException("$name is closed!"))
}
val single = SingleSubject.create<WebsocketResponse>()
val internalRequest = request.toLibSignalRequest()
chatService!!.send(internalRequest)
chatConnection!!.send(internalRequest)
.whenComplete(
onSuccess = { response ->
Log.d(TAG, "$name [sendRequest] Success: ${response!!.status}")
@@ -219,13 +246,13 @@ class LibSignalChatConnection(
in 400..599 -> {
healthMonitor.onMessageError(
status = response.status,
isIdentifiedWebSocket = chatService is AuthenticatedChatService
isIdentifiedWebSocket = chatConnection is AuthenticatedChatConnection
)
}
}
// Here success means "we received the response" even if it is reporting an error.
// This is consistent with the behavior of the OkHttpWebSocketConnection.
single.onSuccess(response.toWebsocketResponse(isUnidentified = (chatService is UnauthenticatedChatService)))
single.onSuccess(response.toWebsocketResponse(isUnidentified = (chatConnection is UnauthenticatedChatConnection)))
},
onFailure = { throwable ->
Log.w(TAG, "$name [sendRequest] Failure:", throwable)
@@ -238,29 +265,29 @@ class LibSignalChatConnection(
override fun sendKeepAlive() {
CHAT_SERVICE_LOCK.withLock {
if (chatService == null) {
if (isDead()) {
return
}
Log.i(TAG, "$name Sending keep alive...")
chatService!!.sendAndDebug(KEEP_ALIVE_REQUEST)
chatConnection!!.send(KEEP_ALIVE_REQUEST)
.whenComplete(
onSuccess = { debugResponse ->
onSuccess = { response ->
Log.d(TAG, "$name [sendKeepAlive] Success")
when (debugResponse!!.response.status) {
when (response!!.status) {
in 200..299 -> {
healthMonitor.onKeepAliveResponse(
sentTimestamp = Instant.now().toEpochMilli(), // ignored. can be any value
isIdentifiedWebSocket = chatService is AuthenticatedChatService
isIdentifiedWebSocket = chatConnection is AuthenticatedChatConnection
)
}
in 400..599 -> {
healthMonitor.onMessageError(debugResponse.response.status, (chatService is AuthenticatedChatService))
healthMonitor.onMessageError(response.status, (chatConnection is AuthenticatedChatConnection))
}
else -> {
Log.w(TAG, "$name [sendKeepAlive] Unsupported keep alive response status: ${debugResponse.response.status}")
Log.w(TAG, "$name [sendKeepAlive] Unsupported keep alive response status: ${response.status}")
}
}
},
@@ -310,8 +337,8 @@ class LibSignalChatConnection(
private val listener = LibSignalChatListener()
private inner class LibSignalChatListener : ChatListener {
override fun onIncomingMessage(chat: ChatService, envelope: ByteArray, serverDeliveryTimestamp: Long, sendAck: ChatListener.ServerMessageAck?) {
private inner class LibSignalChatListener : ChatConnectionListener {
override fun onIncomingMessage(chat: ChatConnection, envelope: ByteArray, serverDeliveryTimestamp: Long, sendAck: ChatConnectionListener.ServerMessageAck?) {
// NB: The order here is intentional to ensure concurrency-safety, so that when a request is pulled off the queue, its sendAck is
// already in the ackSender map, if it exists.
val internalPseudoId = nextIncomingMessageInternalPseudoId.getAndIncrement()
@@ -328,15 +355,15 @@ class LibSignalChatConnection(
incomingRequestQueue.put(incomingWebSocketRequest)
}
override fun onConnectionInterrupted(chat: ChatService, disconnectReason: ChatServiceException) {
override fun onConnectionInterrupted(chat: ChatConnection, disconnectReason: ChatServiceException) {
CHAT_SERVICE_LOCK.withLock {
Log.i(TAG, "$name connection interrupted", disconnectReason)
chatService = null
chatConnection = null
state.onNext(WebSocketConnectionState.DISCONNECTED)
}
}
override fun onQueueEmpty(chat: ChatService) {
override fun onQueueEmpty(chat: ChatConnection) {
val internalPseudoId = nextIncomingMessageInternalPseudoId.getAndIncrement()
val queueEmptyRequest = WebSocketRequestMessage(
verb = SOCKET_EMPTY_REQUEST_VERB,

View File

@@ -7,29 +7,9 @@
package org.whispersystems.signalservice.internal.websocket
import org.signal.core.util.orNull
import org.signal.libsignal.net.ChatListener
import org.signal.libsignal.net.ChatService
import org.signal.libsignal.net.Network
import org.whispersystems.signalservice.api.util.CredentialsProvider
import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration
/**
* Helper method to create a ChatService with optional credentials.
*/
fun Network.createChatService(
credentialsProvider: CredentialsProvider? = null,
receiveStories: Boolean,
listener: ChatListener? = null
): ChatService {
val username = credentialsProvider?.username ?: ""
val password = credentialsProvider?.password ?: ""
return if (username.isEmpty() && password.isEmpty()) {
this.createUnauthChatService(listener)
} else {
this.createAuthChatService(username, password, receiveStories, listener)
}
}
/**
* Helper method to apply settings from the SignalServiceConfiguration.
*/

View File

@@ -10,7 +10,9 @@ import io.reactivex.rxjava3.core.Single
import okhttp3.Response
import okhttp3.WebSocket
import org.signal.core.util.logging.Log
import org.signal.libsignal.net.ChatService
import org.signal.libsignal.net.ChatConnection
import org.signal.libsignal.net.Network
import org.signal.libsignal.net.UnauthenticatedChatConnection
import org.whispersystems.signalservice.api.util.CredentialsProvider
import org.whispersystems.signalservice.api.websocket.HealthMonitor
import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState
@@ -46,7 +48,7 @@ class ShadowingWebSocketConnection(
signalAgent: String,
healthMonitor: HealthMonitor,
allowStories: Boolean,
private val chatService: ChatService,
private val network: Network,
private val shadowPercentage: Int,
private val bridge: WebSocketShadowingBridge
) : OkHttpWebSocketConnection(
@@ -67,19 +69,33 @@ class ShadowingWebSocketConnection(
}
private val canShadow: AtomicBoolean = AtomicBoolean(false)
private val executor: ExecutorService = Executors.newSingleThreadExecutor()
private var chatConnection: UnauthenticatedChatConnection? = null
private var shadowingConnectPending = false
override fun connect(): Observable<WebSocketConnectionState> {
executor.submit {
chatService.connect().whenComplete(
onSuccess = {
canShadow.set(true)
Log.i(TAG, "Shadow socket connected.")
},
onFailure = {
canShadow.set(false)
Log.i(TAG, "Shadow socket failed to connect.")
}
)
// NB: The potential for race conditions here was introduced when we switched from ChatService's
// long lived connection model to the single-use ChatConnection model.
// At this time, we do not intend to ever use this code in production again, so I'm deferring properly
// fixing it with a refactor, and instead just doing the bare minimum to avoid an obvious race.
// If we do want to use this again in production, we should probably refactor to depend on the higher level
// LibSignalChatConnection, rather than the lower level ChatConnection API.
if (chatConnection == null && !shadowingConnectPending) {
shadowingConnectPending = true
executor.submit {
network.connectUnauthChat(null).whenComplete(
onSuccess = { connection ->
shadowingConnectPending = false
chatConnection = connection
canShadow.set(true)
Log.i(TAG, "Shadow socket connected.")
},
onFailure = {
shadowingConnectPending = false
canShadow.set(false)
Log.i(TAG, "Shadow socket failed to connect.")
}
)
}
}
return super.connect()
}
@@ -96,7 +112,7 @@ class ShadowingWebSocketConnection(
override fun disconnect() {
executor.submit {
chatService.disconnect().thenApply {
chatConnection?.disconnect()?.thenApply {
canShadow.set(false)
Log.i(TAG, "Shadow socket disconnected.")
}
@@ -133,22 +149,23 @@ class ShadowingWebSocketConnection(
}
private fun libsignalKeepAlive(actualResponse: WebsocketResponse) {
val request = ChatService.Request(
val connection = chatConnection ?: return
val request = ChatConnection.Request(
"GET",
"/v1/keepalive",
emptyMap(),
ByteArray(0),
KEEP_ALIVE_TIMEOUT.inWholeMilliseconds.toInt()
)
chatService.sendAndDebug(request)
.whenComplete(
onSuccess = {
connection.send(request)
?.whenComplete(
onSuccess = { response ->
stats.requestsCompared.incrementAndGet()
val goodStatus = (it?.response?.status ?: -1) in 200..299
val goodStatus = (response?.status ?: -1) in 200..299
if (!goodStatus) {
stats.badStatuses.incrementAndGet()
}
Log.i(TAG, "$it")
Log.i(TAG, response?.message)
},
onFailure = {
stats.requestsCompared.incrementAndGet()

View File

@@ -1,9 +1,9 @@
package org.whispersystems.signalservice.internal.websocket
import io.mockk.clearAllMocks
import io.mockk.clearMocks
import io.mockk.every
import io.mockk.mockk
import io.mockk.mockkStatic
import io.mockk.verify
import io.reactivex.rxjava3.observers.TestObserver
import okio.ByteString.Companion.toByteString
@@ -12,12 +12,11 @@ import org.junit.Assert.assertTrue
import org.junit.Before
import org.junit.Test
import org.signal.libsignal.internal.CompletableFuture
import org.signal.libsignal.net.ChatListener
import org.signal.libsignal.net.ChatService
import org.signal.libsignal.net.ChatService.DebugInfo
import org.signal.libsignal.net.ChatConnection
import org.signal.libsignal.net.ChatConnectionListener
import org.signal.libsignal.net.ChatServiceException
import org.signal.libsignal.net.IpType
import org.signal.libsignal.net.Network
import org.signal.libsignal.net.UnauthenticatedChatConnection
import org.whispersystems.signalservice.api.websocket.HealthMonitor
import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState
import java.util.concurrent.CountDownLatch
@@ -25,51 +24,76 @@ import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeoutException
import org.signal.libsignal.net.ChatService.Response as LibSignalResponse
import org.signal.libsignal.net.ChatService.ResponseAndDebugInfo as LibSignalDebugResponse
class LibSignalChatConnectionTest {
private val executor: ExecutorService = Executors.newSingleThreadExecutor()
private val healthMonitor = mockk<HealthMonitor>()
private val chatService = mockk<ChatService>()
private val network = mockk<Network>()
private val connection = LibSignalChatConnection("test", network, null, false, healthMonitor)
private var chatListener: ChatListener? = null
private val chatConnection = mockk<UnauthenticatedChatConnection>()
private var chatListener: ChatConnectionListener? = null
// Used by default-success mocks for ChatConnection behavior.
private var connectLatch: CountDownLatch? = null
private var disconnectLatch: CountDownLatch? = null
private var sendLatch: CountDownLatch? = null
private fun setupConnectedConnection() {
connectLatch = CountDownLatch(1)
connection.connect()
connectLatch!!.await(100, TimeUnit.MILLISECONDS)
}
@Before
fun before() {
clearAllMocks()
mockkStatic(Network::createChatService)
every { healthMonitor.onMessageError(any(), any()) }
every { healthMonitor.onKeepAliveResponse(any(), any()) }
every { network.createChatService(any(), any(), any()) } answers {
// When mocking static methods in mockk, the mock target is included as the first
// argument in the answers block. This results in the thirdArgument<T>() convenience method
// being off-by-one. Since we are interested in the last argument to createChatService, we need
// to manually fetch it from the args array and cast it ourselves.
chatListener = args[3] as ChatListener?
chatService
}
}
@Test
fun orderOfStatesOnSuccessfulConnect() {
val latch = CountDownLatch(1)
every { chatService.connect() } answers {
// NB: We provide default success behavior mocks here to cut down on boilerplate later, but it is
// expected that some tests will override some of these to test failures.
//
// We provide a null credentials provider when creating `connection`, so LibSignalChatConnection
// should always call connectUnauthChat()
// TODO: Maybe also test Auth? The old one didn't.
every { network.connectUnauthChat(any()) } answers {
chatListener = firstArg()
delay {
it.complete(DEBUG_INFO)
latch.countDown()
it.complete(chatConnection)
connectLatch?.countDown()
}
}
every { chatConnection.disconnect() } answers {
delay {
it.complete(null)
disconnectLatch?.countDown()
}
}
every { chatConnection.send(any()) } answers {
delay {
it.complete(RESPONSE_SUCCESS)
sendLatch?.countDown()
}
}
every { chatConnection.start() } returns Unit
}
// Test that the LibSignalChatConnection transitions through DISCONNECTED -> CONNECTING -> CONNECTED
// if the underlying ChatConnection future completes successfully.
@Test
fun orderOfStatesOnSuccessfulConnect() {
connectLatch = CountDownLatch(1)
val observer = TestObserver<WebSocketConnectionState>()
connection.state.subscribe(observer)
connection.connect()
latch.await(100, TimeUnit.MILLISECONDS)
connectLatch!!.await(100, TimeUnit.MILLISECONDS)
observer.assertNotComplete()
observer.assertValues(
@@ -79,14 +103,18 @@ class LibSignalChatConnectionTest {
)
}
// Test that the LibSignalChatConnection transitions to FAILED if the
// underlying ChatConnection future completes exceptionally.
@Test
fun orderOfStatesOnConnectionFailure() {
val connectionException = RuntimeException("connect failed")
val latch = CountDownLatch(1)
every { chatService.connect() } answers {
every { network.connectUnauthChat(any()) } answers {
chatListener = firstArg()
delay {
it.completeExceptionally(connectionException)
latch.countDown()
}
}
@@ -105,32 +133,21 @@ class LibSignalChatConnectionTest {
)
}
// Test connect followed by disconnect, checking the state transitions.
@Test
fun orderOfStatesOnConnectAndDisconnect() {
val connectLatch = CountDownLatch(1)
val disconnectLatch = CountDownLatch(1)
every { chatService.connect() } answers {
delay {
it.complete(DEBUG_INFO)
connectLatch.countDown()
}
}
every { chatService.disconnect() } answers {
delay {
it.complete(null)
disconnectLatch.countDown()
}
}
connectLatch = CountDownLatch(1)
disconnectLatch = CountDownLatch(1)
val observer = TestObserver<WebSocketConnectionState>()
connection.state.subscribe(observer)
connection.connect()
connectLatch.await(100, TimeUnit.MILLISECONDS)
connectLatch!!.await(100, TimeUnit.MILLISECONDS)
connection.disconnect()
disconnectLatch.await(100, TimeUnit.MILLISECONDS)
disconnectLatch!!.await(100, TimeUnit.MILLISECONDS)
observer.assertNotComplete()
observer.assertValues(
@@ -142,30 +159,21 @@ class LibSignalChatConnectionTest {
)
}
// Test that a disconnect failure transitions from CONNECTED -> DISCONNECTING -> DISCONNECTED anyway,
// since we don't have a specific "DISCONNECT_FAILED" state.
@Test
fun orderOfStatesOnDisconnectFailure() {
val disconnectException = RuntimeException("disconnect failed")
val connectLatch = CountDownLatch(1)
val disconnectLatch = CountDownLatch(1)
every { chatService.disconnect() } answers {
every { chatConnection.disconnect() } answers {
delay {
it.completeExceptionally(disconnectException)
disconnectLatch.countDown()
}
}
every { chatService.connect() } answers {
delay {
it.complete(DEBUG_INFO)
connectLatch.countDown()
}
}
connection.connect()
connectLatch.await(100, TimeUnit.MILLISECONDS)
setupConnectedConnection()
val observer = TestObserver<WebSocketConnectionState>()
connection.state.subscribe(observer)
@@ -176,34 +184,23 @@ class LibSignalChatConnectionTest {
observer.assertNotComplete()
observer.assertValues(
// The subscriber is created after we've already connected, so the first state it sees is CONNECTED:
WebSocketConnectionState.CONNECTED,
WebSocketConnectionState.DISCONNECTING,
WebSocketConnectionState.DISCONNECTED
)
}
// Test a successful keepAlive, i.e. we get a 200 OK in response to the keepAlive request,
// which triggers healthMonitor.onKeepAliveResponse(...) and not onMessageError.
@Test
fun keepAliveSuccess() {
val latch = CountDownLatch(1)
setupConnectedConnection()
every { chatService.sendAndDebug(any()) } answers {
delay {
it.complete(make_debug_response(RESPONSE_SUCCESS))
latch.countDown()
}
}
every { chatService.connect() } answers {
delay {
it.complete(DEBUG_INFO)
}
}
connection.connect()
sendLatch = CountDownLatch(1)
connection.sendKeepAlive()
latch.await(100, TimeUnit.MILLISECONDS)
sendLatch!!.await(100, TimeUnit.MILLISECONDS)
verify(exactly = 1) {
healthMonitor.onKeepAliveResponse(any(), false)
@@ -213,27 +210,25 @@ class LibSignalChatConnectionTest {
}
}
// Test keepAlive failures: we get 4xx or 5xx, which triggers healthMonitor.onMessageError(...) but not onKeepAliveResponse.
@Test
fun keepAliveFailure() {
for (response in listOf(RESPONSE_ERROR, RESPONSE_SERVER_ERROR)) {
val latch = CountDownLatch(1)
clearMocks(healthMonitor)
every { chatService.sendAndDebug(any()) } answers {
every { chatConnection.send(any()) } answers {
delay {
it.complete(make_debug_response(response))
it.complete(response)
sendLatch?.countDown()
}
}
every { chatService.connect() } answers {
delay {
it.complete(DEBUG_INFO)
}
}
setupConnectedConnection()
connection.connect()
sendLatch = CountDownLatch(1)
connection.sendKeepAlive()
latch.await(100, TimeUnit.MILLISECONDS)
sendLatch!!.await(100, TimeUnit.MILLISECONDS)
verify(exactly = 1) {
healthMonitor.onMessageError(response.status, false)
@@ -244,31 +239,22 @@ class LibSignalChatConnectionTest {
}
}
// Test keepAlive that fails at the transport layer (send() throws),
// which transitions from CONNECTED -> DISCONNECTED.
@Test
fun keepAliveConnectionFailure() {
val connectionFailure = RuntimeException("Sending keep-alive failed")
val connectLatch = CountDownLatch(1)
val keepAliveFailureLatch = CountDownLatch(1)
every {
chatService.sendAndDebug(any())
} answers {
every { chatConnection.send(any()) } answers {
delay {
it.completeExceptionally(connectionFailure)
keepAliveFailureLatch.countDown()
}
}
every { chatService.connect() } answers {
delay {
it.complete(DEBUG_INFO)
connectLatch.countDown()
}
}
connection.connect()
connectLatch.await(100, TimeUnit.MILLISECONDS)
setupConnectedConnection()
val observer = TestObserver<WebSocketConnectionState>()
connection.state.subscribe(observer)
@@ -290,58 +276,17 @@ class LibSignalChatConnectionTest {
}
}
// Test that an incoming "connection interrupted" event from ChatConnection sets our state to DISCONNECTED.
@Test
fun connectionInterruptedTest() {
val disconnectReason = ChatServiceException("simulated interrupt")
val connectLatch = CountDownLatch(1)
every { chatService.connect() } answers {
delay {
it.complete(DEBUG_INFO)
connectLatch.countDown()
}
}
connection.connect()
connectLatch.await(100, TimeUnit.MILLISECONDS)
setupConnectedConnection()
val observer = TestObserver<WebSocketConnectionState>()
connection.state.subscribe(observer)
chatListener!!.onConnectionInterrupted(chatService, disconnectReason)
observer.assertNotComplete()
observer.assertValues(
// We start in the connected state
WebSocketConnectionState.CONNECTED,
// Disconnects as a result of the connection interrupted event
WebSocketConnectionState.DISCONNECTED
)
verify(exactly = 0) {
healthMonitor.onKeepAliveResponse(any(), any())
healthMonitor.onMessageError(any(), any())
}
}
@Test
fun connectionInterrupted() {
val disconnectReason = ChatServiceException("simulated interrupt")
val connectLatch = CountDownLatch(1)
every { chatService.connect() } answers {
delay {
it.complete(DEBUG_INFO)
connectLatch.countDown()
}
}
connection.connect()
connectLatch.await(100, TimeUnit.MILLISECONDS)
val observer = TestObserver<WebSocketConnectionState>()
connection.state.subscribe(observer)
chatListener!!.onConnectionInterrupted(chatService, disconnectReason)
chatListener!!.onConnectionInterrupted(chatConnection, disconnectReason)
observer.assertNotComplete()
observer.assertValues(
@@ -356,36 +301,32 @@ class LibSignalChatConnectionTest {
}
}
// Test reading incoming requests from the queue.
// We'll simulate onIncomingMessage() from the ChatConnectionListener, then read them from the LibSignalChatConnection.
@Test
fun incomingRequests() {
val connectLatch = CountDownLatch(1)
val asyncMessageReadLatch = CountDownLatch(1)
every { chatService.connect() } answers {
delay {
it.complete(DEBUG_INFO)
connectLatch.countDown()
}
}
connection.connect()
connectLatch.await(100, TimeUnit.MILLISECONDS)
setupConnectedConnection()
val observer = TestObserver<WebSocketConnectionState>()
connection.state.subscribe(observer)
// Confirm that readRequest times out if there's no message.
var timedOut = false
try {
connection.readRequest(10)
} catch (e: TimeoutException) {
timedOut = true
}
assert(timedOut)
assertTrue(timedOut)
// We'll now simulate incoming messages
val envelopeA = "msgA".toByteArray()
val envelopeB = "msgB".toByteArray()
val envelopeC = "msgC".toByteArray()
val asyncMessageReadLatch = CountDownLatch(1)
// Helper to check that the WebSocketRequestMessage for an envelope is as expected
fun assertRequestWithEnvelope(request: WebSocketRequestMessage, envelope: ByteArray) {
assertEquals("PUT", request.verb)
assertEquals("/api/v1/message", request.path)
@@ -399,6 +340,7 @@ class LibSignalChatConnectionTest {
)
}
// Helper to check that a queue-empty request is as expected
fun assertQueueEmptyRequest(request: WebSocketRequestMessage) {
assertEquals("PUT", request.verb)
assertEquals("/api/v1/queue/empty", request.path)
@@ -411,20 +353,23 @@ class LibSignalChatConnectionTest {
)
}
// Read request asynchronously to simulate concurrency
executor.submit {
assertRequestWithEnvelope(connection.readRequest(10), envelopeA)
val request = connection.readRequest(200)
assertRequestWithEnvelope(request, envelopeA)
asyncMessageReadLatch.countDown()
}
chatListener!!.onIncomingMessage(chatService, envelopeA, 0, null)
chatListener!!.onIncomingMessage(chatConnection, envelopeA, 0, null)
asyncMessageReadLatch.await(100, TimeUnit.MILLISECONDS)
chatListener!!.onIncomingMessage(chatService, envelopeB, 0, null)
chatListener!!.onIncomingMessage(chatConnection, envelopeB, 0, null)
assertRequestWithEnvelope(connection.readRequestIfAvailable().get(), envelopeB)
chatListener!!.onQueueEmpty(chatService)
chatListener!!.onQueueEmpty(chatConnection)
assertQueueEmptyRequest(connection.readRequestIfAvailable().get())
chatListener!!.onIncomingMessage(chatService, envelopeC, 0, null)
chatListener!!.onIncomingMessage(chatConnection, envelopeC, 0, null)
assertRequestWithEnvelope(connection.readRequestIfAvailable().get(), envelopeC)
assertTrue(connection.readRequestIfAvailable().isEmpty)
@@ -439,13 +384,9 @@ class LibSignalChatConnectionTest {
}
companion object {
private val DEBUG_INFO: DebugInfo = DebugInfo(IpType.UNKNOWN, 100, "")
private val RESPONSE_SUCCESS = LibSignalResponse(200, "", emptyMap(), byteArrayOf())
private val RESPONSE_ERROR = LibSignalResponse(400, "", emptyMap(), byteArrayOf())
private val RESPONSE_SERVER_ERROR = LibSignalResponse(500, "", emptyMap(), byteArrayOf())
private fun make_debug_response(response: LibSignalResponse): LibSignalDebugResponse {
return LibSignalDebugResponse(response, DEBUG_INFO)
}
// For verifying success / error scenarios in keepAlive tests, etc.
private val RESPONSE_SUCCESS = ChatConnection.Response(200, "", emptyMap(), byteArrayOf())
private val RESPONSE_ERROR = ChatConnection.Response(400, "", emptyMap(), byteArrayOf())
private val RESPONSE_SERVER_ERROR = ChatConnection.Response(500, "", emptyMap(), byteArrayOf())
}
}