mirror of
https://github.com/signalapp/Signal-Android.git
synced 2025-12-23 04:28:35 +00:00
Make LibSignalChatConnection Only Use Each ChatService Once
This commit is contained in:
committed by
Greyson Parrelli
parent
040d05a0a6
commit
1401256ffd
@@ -434,7 +434,9 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider {
|
|||||||
Network network = libSignalNetworkSupplier.get();
|
Network network = libSignalNetworkSupplier.get();
|
||||||
return new LibSignalChatConnection(
|
return new LibSignalChatConnection(
|
||||||
"libsignal-unauth",
|
"libsignal-unauth",
|
||||||
LibSignalNetworkExtensions.createChatService(network, null, Stories.isFeatureEnabled()),
|
network,
|
||||||
|
null,
|
||||||
|
Stories.isFeatureEnabled(),
|
||||||
healthMonitor);
|
healthMonitor);
|
||||||
} else {
|
} else {
|
||||||
return new OkHttpWebSocketConnection("unidentified",
|
return new OkHttpWebSocketConnection("unidentified",
|
||||||
|
|||||||
@@ -13,12 +13,17 @@ import io.reactivex.rxjava3.subjects.SingleSubject
|
|||||||
import org.signal.core.util.logging.Log
|
import org.signal.core.util.logging.Log
|
||||||
import org.signal.libsignal.net.AuthenticatedChatService
|
import org.signal.libsignal.net.AuthenticatedChatService
|
||||||
import org.signal.libsignal.net.ChatService
|
import org.signal.libsignal.net.ChatService
|
||||||
|
import org.signal.libsignal.net.Network
|
||||||
import org.signal.libsignal.net.UnauthenticatedChatService
|
import org.signal.libsignal.net.UnauthenticatedChatService
|
||||||
|
import org.whispersystems.signalservice.api.util.CredentialsProvider
|
||||||
import org.whispersystems.signalservice.api.websocket.HealthMonitor
|
import org.whispersystems.signalservice.api.websocket.HealthMonitor
|
||||||
import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState
|
import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState
|
||||||
import org.whispersystems.signalservice.internal.util.whenComplete
|
import org.whispersystems.signalservice.internal.util.whenComplete
|
||||||
|
import java.io.IOException
|
||||||
import java.time.Instant
|
import java.time.Instant
|
||||||
import java.util.Optional
|
import java.util.Optional
|
||||||
|
import java.util.concurrent.locks.ReentrantLock
|
||||||
|
import kotlin.concurrent.withLock
|
||||||
import kotlin.time.Duration.Companion.seconds
|
import kotlin.time.Duration.Companion.seconds
|
||||||
import org.signal.libsignal.net.ChatService.Request as LibSignalRequest
|
import org.signal.libsignal.net.ChatService.Request as LibSignalRequest
|
||||||
import org.signal.libsignal.net.ChatService.Response as LibSignalResponse
|
import org.signal.libsignal.net.ChatService.Response as LibSignalResponse
|
||||||
@@ -38,10 +43,15 @@ import org.signal.libsignal.net.ChatService.Response as LibSignalResponse
|
|||||||
*/
|
*/
|
||||||
class LibSignalChatConnection(
|
class LibSignalChatConnection(
|
||||||
name: String,
|
name: String,
|
||||||
private val chatService: ChatService,
|
private val network: Network,
|
||||||
|
private val credentialsProvider: CredentialsProvider?,
|
||||||
|
private val receiveStories: Boolean,
|
||||||
private val healthMonitor: HealthMonitor
|
private val healthMonitor: HealthMonitor
|
||||||
) : WebSocketConnection {
|
) : WebSocketConnection {
|
||||||
|
|
||||||
|
private val CHAT_SERVICE_LOCK = ReentrantLock()
|
||||||
|
private var chatService: ChatService? = null
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
private val TAG = Log.tag(LibSignalChatConnection::class.java)
|
private val TAG = Log.tag(LibSignalChatConnection::class.java)
|
||||||
private val SEND_TIMEOUT: Long = 10.seconds.inWholeMilliseconds
|
private val SEND_TIMEOUT: Long = 10.seconds.inWholeMilliseconds
|
||||||
@@ -85,46 +95,64 @@ class LibSignalChatConnection(
|
|||||||
val state = BehaviorSubject.createDefault(WebSocketConnectionState.DISCONNECTED)
|
val state = BehaviorSubject.createDefault(WebSocketConnectionState.DISCONNECTED)
|
||||||
|
|
||||||
override fun connect(): Observable<WebSocketConnectionState> {
|
override fun connect(): Observable<WebSocketConnectionState> {
|
||||||
|
CHAT_SERVICE_LOCK.withLock {
|
||||||
|
if (chatService != null) {
|
||||||
|
return state
|
||||||
|
}
|
||||||
|
|
||||||
Log.i(TAG, "$name Connecting...")
|
Log.i(TAG, "$name Connecting...")
|
||||||
|
chatService = network.createChatService(credentialsProvider, receiveStories).apply {
|
||||||
state.onNext(WebSocketConnectionState.CONNECTING)
|
state.onNext(WebSocketConnectionState.CONNECTING)
|
||||||
chatService.connect()
|
connect().whenComplete(
|
||||||
.whenComplete(
|
|
||||||
onSuccess = { debugInfo ->
|
onSuccess = { debugInfo ->
|
||||||
Log.i(TAG, "$name Connected")
|
Log.i(TAG, "$name Connected")
|
||||||
Log.d(TAG, "$name $debugInfo")
|
Log.d(TAG, "$name $debugInfo")
|
||||||
state.onNext(WebSocketConnectionState.CONNECTED)
|
state.onNext(WebSocketConnectionState.CONNECTED)
|
||||||
},
|
},
|
||||||
onFailure = { throwable ->
|
onFailure = { throwable ->
|
||||||
// TODO: [libsignal-net] Report WebSocketConnectionState.AUTHENTICATION_FAILED for 401 and 403 errors
|
// TODO[libsignal-net]: Report AUTHENTICATION_FAILED for 401 and 403 errors
|
||||||
Log.d(TAG, "$name Connect failed", throwable)
|
Log.w(TAG, "$name Connect failed", throwable)
|
||||||
state.onNext(WebSocketConnectionState.FAILED)
|
state.onNext(WebSocketConnectionState.FAILED)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
}
|
||||||
return state
|
return state
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
override fun isDead(): Boolean = false
|
override fun isDead(): Boolean = false
|
||||||
|
|
||||||
override fun disconnect() {
|
override fun disconnect() {
|
||||||
|
CHAT_SERVICE_LOCK.withLock {
|
||||||
|
if (chatService == null) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
Log.i(TAG, "$name Disconnecting...")
|
Log.i(TAG, "$name Disconnecting...")
|
||||||
state.onNext(WebSocketConnectionState.DISCONNECTING)
|
state.onNext(WebSocketConnectionState.DISCONNECTING)
|
||||||
chatService.disconnect()
|
chatService!!.disconnect()
|
||||||
.whenComplete(
|
.whenComplete(
|
||||||
onSuccess = {
|
onSuccess = {
|
||||||
Log.i(TAG, "$name Disconnected")
|
Log.i(TAG, "$name Disconnected")
|
||||||
state.onNext(WebSocketConnectionState.DISCONNECTED)
|
state.onNext(WebSocketConnectionState.DISCONNECTED)
|
||||||
},
|
},
|
||||||
onFailure = { throwable ->
|
onFailure = { throwable ->
|
||||||
Log.d(TAG, "$name Disconnect failed", throwable)
|
Log.w(TAG, "$name Disconnect failed", throwable)
|
||||||
state.onNext(WebSocketConnectionState.DISCONNECTED)
|
state.onNext(WebSocketConnectionState.DISCONNECTED)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
chatService = null
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun sendRequest(request: WebSocketRequestMessage): Single<WebsocketResponse> {
|
override fun sendRequest(request: WebSocketRequestMessage): Single<WebsocketResponse> {
|
||||||
|
CHAT_SERVICE_LOCK.withLock {
|
||||||
|
if (chatService == null) {
|
||||||
|
return Single.error(IOException("[$name] is closed!"))
|
||||||
|
}
|
||||||
val single = SingleSubject.create<WebsocketResponse>()
|
val single = SingleSubject.create<WebsocketResponse>()
|
||||||
val internalRequest = request.toLibSignalRequest()
|
val internalRequest = request.toLibSignalRequest()
|
||||||
chatService.send(internalRequest)
|
chatService!!.send(internalRequest)
|
||||||
.whenComplete(
|
.whenComplete(
|
||||||
onSuccess = { response ->
|
onSuccess = { response ->
|
||||||
when (response!!.status) {
|
when (response!!.status) {
|
||||||
@@ -137,20 +165,25 @@ class LibSignalChatConnection(
|
|||||||
single.onSuccess(response.toWebsocketResponse(isUnidentified = (chatService is UnauthenticatedChatService)))
|
single.onSuccess(response.toWebsocketResponse(isUnidentified = (chatService is UnauthenticatedChatService)))
|
||||||
},
|
},
|
||||||
onFailure = { throwable ->
|
onFailure = { throwable ->
|
||||||
Log.i(TAG, "$name sendRequest failed", throwable)
|
Log.w(TAG, "$name sendRequest failed", throwable)
|
||||||
single.onError(throwable)
|
single.onError(throwable)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return single.subscribeOn(Schedulers.io()).observeOn(Schedulers.io())
|
return single.subscribeOn(Schedulers.io()).observeOn(Schedulers.io())
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
override fun sendKeepAlive() {
|
override fun sendKeepAlive() {
|
||||||
|
CHAT_SERVICE_LOCK.withLock {
|
||||||
|
if (chatService == null) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
Log.i(TAG, "$name Sending keep alive...")
|
Log.i(TAG, "$name Sending keep alive...")
|
||||||
chatService.sendAndDebug(KEEP_ALIVE_REQUEST)
|
chatService!!.sendAndDebug(KEEP_ALIVE_REQUEST)
|
||||||
.whenComplete(
|
.whenComplete(
|
||||||
onSuccess = { debugResponse ->
|
onSuccess = { debugResponse ->
|
||||||
Log.i(TAG, "$name Keep alive - success")
|
Log.d(TAG, "$name Keep alive - success")
|
||||||
Log.d(TAG, "$name $debugResponse")
|
|
||||||
when (debugResponse!!.response.status) {
|
when (debugResponse!!.response.status) {
|
||||||
in 200..299 -> {
|
in 200..299 -> {
|
||||||
healthMonitor.onKeepAliveResponse(
|
healthMonitor.onKeepAliveResponse(
|
||||||
@@ -169,12 +202,12 @@ class LibSignalChatConnection(
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
onFailure = { throwable ->
|
onFailure = { throwable ->
|
||||||
Log.i(TAG, "$name Keep alive - failed")
|
Log.w(TAG, "$name Keep alive - failed", throwable)
|
||||||
Log.d(TAG, "$name $throwable")
|
|
||||||
state.onNext(WebSocketConnectionState.DISCONNECTED)
|
state.onNext(WebSocketConnectionState.DISCONNECTED)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
override fun readRequestIfAvailable(): Optional<WebSocketRequestMessage> {
|
override fun readRequestIfAvailable(): Optional<WebSocketRequestMessage> {
|
||||||
throw NotImplementedError()
|
throw NotImplementedError()
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package org.whispersystems.signalservice.internal.websocket
|
|||||||
import io.mockk.clearAllMocks
|
import io.mockk.clearAllMocks
|
||||||
import io.mockk.every
|
import io.mockk.every
|
||||||
import io.mockk.mockk
|
import io.mockk.mockk
|
||||||
|
import io.mockk.mockkStatic
|
||||||
import io.mockk.verify
|
import io.mockk.verify
|
||||||
import io.reactivex.rxjava3.observers.TestObserver
|
import io.reactivex.rxjava3.observers.TestObserver
|
||||||
import org.junit.Before
|
import org.junit.Before
|
||||||
@@ -11,6 +12,7 @@ import org.signal.libsignal.internal.CompletableFuture
|
|||||||
import org.signal.libsignal.net.ChatService
|
import org.signal.libsignal.net.ChatService
|
||||||
import org.signal.libsignal.net.ChatService.DebugInfo
|
import org.signal.libsignal.net.ChatService.DebugInfo
|
||||||
import org.signal.libsignal.net.IpType
|
import org.signal.libsignal.net.IpType
|
||||||
|
import org.signal.libsignal.net.Network
|
||||||
import org.whispersystems.signalservice.api.websocket.HealthMonitor
|
import org.whispersystems.signalservice.api.websocket.HealthMonitor
|
||||||
import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState
|
import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState
|
||||||
import java.util.concurrent.CountDownLatch
|
import java.util.concurrent.CountDownLatch
|
||||||
@@ -25,13 +27,16 @@ class LibSignalChatConnectionTest {
|
|||||||
private val executor: ExecutorService = Executors.newSingleThreadExecutor()
|
private val executor: ExecutorService = Executors.newSingleThreadExecutor()
|
||||||
private val healthMonitor = mockk<HealthMonitor>()
|
private val healthMonitor = mockk<HealthMonitor>()
|
||||||
private val chatService = mockk<ChatService>()
|
private val chatService = mockk<ChatService>()
|
||||||
private val connection = LibSignalChatConnection("test", chatService, healthMonitor)
|
private val network = mockk<Network>()
|
||||||
|
private val connection = LibSignalChatConnection("test", network, null, false, healthMonitor)
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
fun before() {
|
fun before() {
|
||||||
clearAllMocks()
|
clearAllMocks()
|
||||||
|
mockkStatic(Network::createChatService)
|
||||||
every { healthMonitor.onMessageError(any(), any()) }
|
every { healthMonitor.onMessageError(any(), any()) }
|
||||||
every { healthMonitor.onKeepAliveResponse(any(), any()) }
|
every { healthMonitor.onKeepAliveResponse(any(), any()) }
|
||||||
|
every { network.createChatService(any(), any()) } answers { chatService }
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@@ -127,25 +132,37 @@ class LibSignalChatConnectionTest {
|
|||||||
fun orderOfStatesOnDisconnectFailure() {
|
fun orderOfStatesOnDisconnectFailure() {
|
||||||
val disconnectException = RuntimeException("disconnect failed")
|
val disconnectException = RuntimeException("disconnect failed")
|
||||||
|
|
||||||
val latch = CountDownLatch(1)
|
val connectLatch = CountDownLatch(1)
|
||||||
|
val disconnectLatch = CountDownLatch(1)
|
||||||
|
|
||||||
every { chatService.disconnect() } answers {
|
every { chatService.disconnect() } answers {
|
||||||
delay {
|
delay {
|
||||||
it.completeExceptionally(disconnectException)
|
it.completeExceptionally(disconnectException)
|
||||||
|
disconnectLatch.countDown()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
every { chatService.connect() } answers {
|
||||||
|
delay {
|
||||||
|
it.complete(DEBUG_INFO)
|
||||||
|
connectLatch.countDown()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
connection.connect()
|
||||||
|
|
||||||
|
connectLatch.await(100, TimeUnit.MILLISECONDS)
|
||||||
|
|
||||||
val observer = TestObserver<WebSocketConnectionState>()
|
val observer = TestObserver<WebSocketConnectionState>()
|
||||||
|
|
||||||
connection.state.subscribe(observer)
|
connection.state.subscribe(observer)
|
||||||
|
|
||||||
connection.disconnect()
|
connection.disconnect()
|
||||||
|
|
||||||
latch.await(100, TimeUnit.MILLISECONDS)
|
disconnectLatch.await(100, TimeUnit.MILLISECONDS)
|
||||||
|
|
||||||
observer.assertNotComplete()
|
observer.assertNotComplete()
|
||||||
observer.assertValues(
|
observer.assertValues(
|
||||||
WebSocketConnectionState.DISCONNECTED,
|
WebSocketConnectionState.CONNECTED,
|
||||||
WebSocketConnectionState.DISCONNECTING,
|
WebSocketConnectionState.DISCONNECTING,
|
||||||
WebSocketConnectionState.DISCONNECTED
|
WebSocketConnectionState.DISCONNECTED
|
||||||
)
|
)
|
||||||
@@ -162,6 +179,14 @@ class LibSignalChatConnectionTest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
every { chatService.connect() } answers {
|
||||||
|
delay {
|
||||||
|
it.complete(DEBUG_INFO)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
connection.connect()
|
||||||
|
|
||||||
connection.sendKeepAlive()
|
connection.sendKeepAlive()
|
||||||
|
|
||||||
latch.await(100, TimeUnit.MILLISECONDS)
|
latch.await(100, TimeUnit.MILLISECONDS)
|
||||||
@@ -185,6 +210,14 @@ class LibSignalChatConnectionTest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
every { chatService.connect() } answers {
|
||||||
|
delay {
|
||||||
|
it.complete(DEBUG_INFO)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
connection.connect()
|
||||||
|
|
||||||
connection.sendKeepAlive()
|
connection.sendKeepAlive()
|
||||||
latch.await(100, TimeUnit.MILLISECONDS)
|
latch.await(100, TimeUnit.MILLISECONDS)
|
||||||
|
|
||||||
@@ -200,28 +233,41 @@ class LibSignalChatConnectionTest {
|
|||||||
@Test
|
@Test
|
||||||
fun keepAliveConnectionFailure() {
|
fun keepAliveConnectionFailure() {
|
||||||
val connectionFailure = RuntimeException("Sending keep-alive failed")
|
val connectionFailure = RuntimeException("Sending keep-alive failed")
|
||||||
val latch = CountDownLatch(1)
|
|
||||||
|
val connectLatch = CountDownLatch(1)
|
||||||
|
val keepAliveFailureLatch = CountDownLatch(1)
|
||||||
|
|
||||||
every {
|
every {
|
||||||
chatService.sendAndDebug(any())
|
chatService.sendAndDebug(any())
|
||||||
} answers {
|
} answers {
|
||||||
delay {
|
delay {
|
||||||
it.completeExceptionally(connectionFailure)
|
it.completeExceptionally(connectionFailure)
|
||||||
|
keepAliveFailureLatch.countDown()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
every { chatService.connect() } answers {
|
||||||
|
delay {
|
||||||
|
it.complete(DEBUG_INFO)
|
||||||
|
connectLatch.countDown()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
connection.connect()
|
||||||
|
connectLatch.await(100, TimeUnit.MILLISECONDS)
|
||||||
|
|
||||||
val observer = TestObserver<WebSocketConnectionState>()
|
val observer = TestObserver<WebSocketConnectionState>()
|
||||||
connection.state.subscribe(observer)
|
connection.state.subscribe(observer)
|
||||||
|
|
||||||
connection.sendKeepAlive()
|
connection.sendKeepAlive()
|
||||||
|
|
||||||
latch.await(100, TimeUnit.MILLISECONDS)
|
keepAliveFailureLatch.await(100, TimeUnit.MILLISECONDS)
|
||||||
|
|
||||||
observer.assertNotComplete()
|
observer.assertNotComplete()
|
||||||
observer.assertValues(
|
observer.assertValues(
|
||||||
// This is the starting state
|
// We start in the connected state
|
||||||
WebSocketConnectionState.DISCONNECTED,
|
WebSocketConnectionState.CONNECTED,
|
||||||
// This one is the result of a keep-alive failure
|
// Disconnects as a result of keep-alive failure
|
||||||
WebSocketConnectionState.DISCONNECTED
|
WebSocketConnectionState.DISCONNECTED
|
||||||
)
|
)
|
||||||
verify(exactly = 0) {
|
verify(exactly = 0) {
|
||||||
|
|||||||
Reference in New Issue
Block a user