Make LibSignalChatConnection Only Use Each ChatService Once

This commit is contained in:
andrew-signal
2024-11-20 12:05:11 -05:00
committed by Greyson Parrelli
parent 040d05a0a6
commit 1401256ffd
3 changed files with 165 additions and 84 deletions

View File

@@ -434,7 +434,9 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider {
Network network = libSignalNetworkSupplier.get(); Network network = libSignalNetworkSupplier.get();
return new LibSignalChatConnection( return new LibSignalChatConnection(
"libsignal-unauth", "libsignal-unauth",
LibSignalNetworkExtensions.createChatService(network, null, Stories.isFeatureEnabled()), network,
null,
Stories.isFeatureEnabled(),
healthMonitor); healthMonitor);
} else { } else {
return new OkHttpWebSocketConnection("unidentified", return new OkHttpWebSocketConnection("unidentified",

View File

@@ -13,12 +13,17 @@ import io.reactivex.rxjava3.subjects.SingleSubject
import org.signal.core.util.logging.Log import org.signal.core.util.logging.Log
import org.signal.libsignal.net.AuthenticatedChatService import org.signal.libsignal.net.AuthenticatedChatService
import org.signal.libsignal.net.ChatService import org.signal.libsignal.net.ChatService
import org.signal.libsignal.net.Network
import org.signal.libsignal.net.UnauthenticatedChatService import org.signal.libsignal.net.UnauthenticatedChatService
import org.whispersystems.signalservice.api.util.CredentialsProvider
import org.whispersystems.signalservice.api.websocket.HealthMonitor import org.whispersystems.signalservice.api.websocket.HealthMonitor
import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState
import org.whispersystems.signalservice.internal.util.whenComplete import org.whispersystems.signalservice.internal.util.whenComplete
import java.io.IOException
import java.time.Instant import java.time.Instant
import java.util.Optional import java.util.Optional
import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock
import kotlin.time.Duration.Companion.seconds import kotlin.time.Duration.Companion.seconds
import org.signal.libsignal.net.ChatService.Request as LibSignalRequest import org.signal.libsignal.net.ChatService.Request as LibSignalRequest
import org.signal.libsignal.net.ChatService.Response as LibSignalResponse import org.signal.libsignal.net.ChatService.Response as LibSignalResponse
@@ -38,10 +43,15 @@ import org.signal.libsignal.net.ChatService.Response as LibSignalResponse
*/ */
class LibSignalChatConnection( class LibSignalChatConnection(
name: String, name: String,
private val chatService: ChatService, private val network: Network,
private val credentialsProvider: CredentialsProvider?,
private val receiveStories: Boolean,
private val healthMonitor: HealthMonitor private val healthMonitor: HealthMonitor
) : WebSocketConnection { ) : WebSocketConnection {
private val CHAT_SERVICE_LOCK = ReentrantLock()
private var chatService: ChatService? = null
companion object { companion object {
private val TAG = Log.tag(LibSignalChatConnection::class.java) private val TAG = Log.tag(LibSignalChatConnection::class.java)
private val SEND_TIMEOUT: Long = 10.seconds.inWholeMilliseconds private val SEND_TIMEOUT: Long = 10.seconds.inWholeMilliseconds
@@ -85,95 +95,118 @@ class LibSignalChatConnection(
val state = BehaviorSubject.createDefault(WebSocketConnectionState.DISCONNECTED) val state = BehaviorSubject.createDefault(WebSocketConnectionState.DISCONNECTED)
override fun connect(): Observable<WebSocketConnectionState> { override fun connect(): Observable<WebSocketConnectionState> {
Log.i(TAG, "$name Connecting...") CHAT_SERVICE_LOCK.withLock {
state.onNext(WebSocketConnectionState.CONNECTING) if (chatService != null) {
chatService.connect() return state
.whenComplete( }
onSuccess = { debugInfo ->
Log.i(TAG, "$name Connected") Log.i(TAG, "$name Connecting...")
Log.d(TAG, "$name $debugInfo") chatService = network.createChatService(credentialsProvider, receiveStories).apply {
state.onNext(WebSocketConnectionState.CONNECTED) state.onNext(WebSocketConnectionState.CONNECTING)
}, connect().whenComplete(
onFailure = { throwable -> onSuccess = { debugInfo ->
// TODO: [libsignal-net] Report WebSocketConnectionState.AUTHENTICATION_FAILED for 401 and 403 errors Log.i(TAG, "$name Connected")
Log.d(TAG, "$name Connect failed", throwable) Log.d(TAG, "$name $debugInfo")
state.onNext(WebSocketConnectionState.FAILED) state.onNext(WebSocketConnectionState.CONNECTED)
} },
) onFailure = { throwable ->
return state // TODO[libsignal-net]: Report AUTHENTICATION_FAILED for 401 and 403 errors
Log.w(TAG, "$name Connect failed", throwable)
state.onNext(WebSocketConnectionState.FAILED)
}
)
}
return state
}
} }
override fun isDead(): Boolean = false override fun isDead(): Boolean = false
override fun disconnect() { override fun disconnect() {
Log.i(TAG, "$name Disconnecting...") CHAT_SERVICE_LOCK.withLock {
state.onNext(WebSocketConnectionState.DISCONNECTING) if (chatService == null) {
chatService.disconnect() return
.whenComplete( }
onSuccess = {
Log.i(TAG, "$name Disconnected") Log.i(TAG, "$name Disconnecting...")
state.onNext(WebSocketConnectionState.DISCONNECTED) state.onNext(WebSocketConnectionState.DISCONNECTING)
}, chatService!!.disconnect()
onFailure = { throwable -> .whenComplete(
Log.d(TAG, "$name Disconnect failed", throwable) onSuccess = {
state.onNext(WebSocketConnectionState.DISCONNECTED) Log.i(TAG, "$name Disconnected")
} state.onNext(WebSocketConnectionState.DISCONNECTED)
) },
onFailure = { throwable ->
Log.w(TAG, "$name Disconnect failed", throwable)
state.onNext(WebSocketConnectionState.DISCONNECTED)
}
)
chatService = null
}
} }
override fun sendRequest(request: WebSocketRequestMessage): Single<WebsocketResponse> { override fun sendRequest(request: WebSocketRequestMessage): Single<WebsocketResponse> {
val single = SingleSubject.create<WebsocketResponse>() CHAT_SERVICE_LOCK.withLock {
val internalRequest = request.toLibSignalRequest() if (chatService == null) {
chatService.send(internalRequest) return Single.error(IOException("[$name] is closed!"))
.whenComplete( }
onSuccess = { response -> val single = SingleSubject.create<WebsocketResponse>()
when (response!!.status) { val internalRequest = request.toLibSignalRequest()
in 400..599 -> { chatService!!.send(internalRequest)
healthMonitor.onMessageError(response.status, false) .whenComplete(
onSuccess = { response ->
when (response!!.status) {
in 400..599 -> {
healthMonitor.onMessageError(response.status, false)
}
} }
// Here success means "we received the response" even if it is reporting an error.
// This is consistent with the behavior of the OkHttpWebSocketConnection.
single.onSuccess(response.toWebsocketResponse(isUnidentified = (chatService is UnauthenticatedChatService)))
},
onFailure = { throwable ->
Log.w(TAG, "$name sendRequest failed", throwable)
single.onError(throwable)
} }
// Here success means "we received the response" even if it is reporting an error. )
// This is consistent with the behavior of the OkHttpWebSocketConnection. return single.subscribeOn(Schedulers.io()).observeOn(Schedulers.io())
single.onSuccess(response.toWebsocketResponse(isUnidentified = (chatService is UnauthenticatedChatService))) }
},
onFailure = { throwable ->
Log.i(TAG, "$name sendRequest failed", throwable)
single.onError(throwable)
}
)
return single.subscribeOn(Schedulers.io()).observeOn(Schedulers.io())
} }
override fun sendKeepAlive() { override fun sendKeepAlive() {
Log.i(TAG, "$name Sending keep alive...") CHAT_SERVICE_LOCK.withLock {
chatService.sendAndDebug(KEEP_ALIVE_REQUEST) if (chatService == null) {
.whenComplete( return
onSuccess = { debugResponse -> }
Log.i(TAG, "$name Keep alive - success")
Log.d(TAG, "$name $debugResponse")
when (debugResponse!!.response.status) {
in 200..299 -> {
healthMonitor.onKeepAliveResponse(
Instant.now().toEpochMilli(), // ignored. can be any value
false
)
}
in 400..599 -> { Log.i(TAG, "$name Sending keep alive...")
healthMonitor.onMessageError(debugResponse.response.status, (chatService is AuthenticatedChatService)) chatService!!.sendAndDebug(KEEP_ALIVE_REQUEST)
} .whenComplete(
onSuccess = { debugResponse ->
Log.d(TAG, "$name Keep alive - success")
when (debugResponse!!.response.status) {
in 200..299 -> {
healthMonitor.onKeepAliveResponse(
Instant.now().toEpochMilli(), // ignored. can be any value
false
)
}
else -> { in 400..599 -> {
Log.w(TAG, "$name Unsupported keep alive response status: ${debugResponse.response.status}") healthMonitor.onMessageError(debugResponse.response.status, (chatService is AuthenticatedChatService))
}
else -> {
Log.w(TAG, "$name Unsupported keep alive response status: ${debugResponse.response.status}")
}
} }
},
onFailure = { throwable ->
Log.w(TAG, "$name Keep alive - failed", throwable)
state.onNext(WebSocketConnectionState.DISCONNECTED)
} }
}, )
onFailure = { throwable -> }
Log.i(TAG, "$name Keep alive - failed")
Log.d(TAG, "$name $throwable")
state.onNext(WebSocketConnectionState.DISCONNECTED)
}
)
} }
override fun readRequestIfAvailable(): Optional<WebSocketRequestMessage> { override fun readRequestIfAvailable(): Optional<WebSocketRequestMessage> {

View File

@@ -3,6 +3,7 @@ package org.whispersystems.signalservice.internal.websocket
import io.mockk.clearAllMocks import io.mockk.clearAllMocks
import io.mockk.every import io.mockk.every
import io.mockk.mockk import io.mockk.mockk
import io.mockk.mockkStatic
import io.mockk.verify import io.mockk.verify
import io.reactivex.rxjava3.observers.TestObserver import io.reactivex.rxjava3.observers.TestObserver
import org.junit.Before import org.junit.Before
@@ -11,6 +12,7 @@ import org.signal.libsignal.internal.CompletableFuture
import org.signal.libsignal.net.ChatService import org.signal.libsignal.net.ChatService
import org.signal.libsignal.net.ChatService.DebugInfo import org.signal.libsignal.net.ChatService.DebugInfo
import org.signal.libsignal.net.IpType import org.signal.libsignal.net.IpType
import org.signal.libsignal.net.Network
import org.whispersystems.signalservice.api.websocket.HealthMonitor import org.whispersystems.signalservice.api.websocket.HealthMonitor
import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState
import java.util.concurrent.CountDownLatch import java.util.concurrent.CountDownLatch
@@ -25,13 +27,16 @@ class LibSignalChatConnectionTest {
private val executor: ExecutorService = Executors.newSingleThreadExecutor() private val executor: ExecutorService = Executors.newSingleThreadExecutor()
private val healthMonitor = mockk<HealthMonitor>() private val healthMonitor = mockk<HealthMonitor>()
private val chatService = mockk<ChatService>() private val chatService = mockk<ChatService>()
private val connection = LibSignalChatConnection("test", chatService, healthMonitor) private val network = mockk<Network>()
private val connection = LibSignalChatConnection("test", network, null, false, healthMonitor)
@Before @Before
fun before() { fun before() {
clearAllMocks() clearAllMocks()
mockkStatic(Network::createChatService)
every { healthMonitor.onMessageError(any(), any()) } every { healthMonitor.onMessageError(any(), any()) }
every { healthMonitor.onKeepAliveResponse(any(), any()) } every { healthMonitor.onKeepAliveResponse(any(), any()) }
every { network.createChatService(any(), any()) } answers { chatService }
} }
@Test @Test
@@ -127,25 +132,37 @@ class LibSignalChatConnectionTest {
fun orderOfStatesOnDisconnectFailure() { fun orderOfStatesOnDisconnectFailure() {
val disconnectException = RuntimeException("disconnect failed") val disconnectException = RuntimeException("disconnect failed")
val latch = CountDownLatch(1) val connectLatch = CountDownLatch(1)
val disconnectLatch = CountDownLatch(1)
every { chatService.disconnect() } answers { every { chatService.disconnect() } answers {
delay { delay {
it.completeExceptionally(disconnectException) it.completeExceptionally(disconnectException)
disconnectLatch.countDown()
} }
} }
val observer = TestObserver<WebSocketConnectionState>() every { chatService.connect() } answers {
delay {
it.complete(DEBUG_INFO)
connectLatch.countDown()
}
}
connection.connect()
connectLatch.await(100, TimeUnit.MILLISECONDS)
val observer = TestObserver<WebSocketConnectionState>()
connection.state.subscribe(observer) connection.state.subscribe(observer)
connection.disconnect() connection.disconnect()
latch.await(100, TimeUnit.MILLISECONDS) disconnectLatch.await(100, TimeUnit.MILLISECONDS)
observer.assertNotComplete() observer.assertNotComplete()
observer.assertValues( observer.assertValues(
WebSocketConnectionState.DISCONNECTED, WebSocketConnectionState.CONNECTED,
WebSocketConnectionState.DISCONNECTING, WebSocketConnectionState.DISCONNECTING,
WebSocketConnectionState.DISCONNECTED WebSocketConnectionState.DISCONNECTED
) )
@@ -162,6 +179,14 @@ class LibSignalChatConnectionTest {
} }
} }
every { chatService.connect() } answers {
delay {
it.complete(DEBUG_INFO)
}
}
connection.connect()
connection.sendKeepAlive() connection.sendKeepAlive()
latch.await(100, TimeUnit.MILLISECONDS) latch.await(100, TimeUnit.MILLISECONDS)
@@ -185,6 +210,14 @@ class LibSignalChatConnectionTest {
} }
} }
every { chatService.connect() } answers {
delay {
it.complete(DEBUG_INFO)
}
}
connection.connect()
connection.sendKeepAlive() connection.sendKeepAlive()
latch.await(100, TimeUnit.MILLISECONDS) latch.await(100, TimeUnit.MILLISECONDS)
@@ -200,28 +233,41 @@ class LibSignalChatConnectionTest {
@Test @Test
fun keepAliveConnectionFailure() { fun keepAliveConnectionFailure() {
val connectionFailure = RuntimeException("Sending keep-alive failed") val connectionFailure = RuntimeException("Sending keep-alive failed")
val latch = CountDownLatch(1)
val connectLatch = CountDownLatch(1)
val keepAliveFailureLatch = CountDownLatch(1)
every { every {
chatService.sendAndDebug(any()) chatService.sendAndDebug(any())
} answers { } answers {
delay { delay {
it.completeExceptionally(connectionFailure) it.completeExceptionally(connectionFailure)
keepAliveFailureLatch.countDown()
} }
} }
every { chatService.connect() } answers {
delay {
it.complete(DEBUG_INFO)
connectLatch.countDown()
}
}
connection.connect()
connectLatch.await(100, TimeUnit.MILLISECONDS)
val observer = TestObserver<WebSocketConnectionState>() val observer = TestObserver<WebSocketConnectionState>()
connection.state.subscribe(observer) connection.state.subscribe(observer)
connection.sendKeepAlive() connection.sendKeepAlive()
latch.await(100, TimeUnit.MILLISECONDS) keepAliveFailureLatch.await(100, TimeUnit.MILLISECONDS)
observer.assertNotComplete() observer.assertNotComplete()
observer.assertValues( observer.assertValues(
// This is the starting state // We start in the connected state
WebSocketConnectionState.DISCONNECTED, WebSocketConnectionState.CONNECTED,
// This one is the result of a keep-alive failure // Disconnects as a result of keep-alive failure
WebSocketConnectionState.DISCONNECTED WebSocketConnectionState.DISCONNECTED
) )
verify(exactly = 0) { verify(exactly = 0) {