diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/NetworkConnectionListener.kt b/app/src/main/java/org/thoughtcrime/securesms/messages/NetworkConnectionListener.kt index 4d3a58f2ee..c9191cf5fa 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/messages/NetworkConnectionListener.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/NetworkConnectionListener.kt @@ -44,7 +44,7 @@ class NetworkConnectionListener(private val context: Context, private val onNetw networkCapabilities: NetworkCapabilities, callbackType: String, lastLogs: MutableMap - ) { + ): Boolean { val currentLog = buildString { append(callbackType) append(" onCapabilitiesChanged($network, ") @@ -56,7 +56,10 @@ class NetworkConnectionListener(private val context: Context, private val onNetw if (lastLogs[network] != currentLog) { Log.d(TAG, currentLog) lastLogs[network] = currentLog + return true } + + return false } private val networkChangedCallback: ConnectivityManager.NetworkCallback = object : ConnectivityManager.NetworkCallback() { @@ -92,7 +95,9 @@ class NetworkConnectionListener(private val context: Context, private val onNetw override fun onCapabilitiesChanged(network: Network, networkCapabilities: NetworkCapabilities) { super.onCapabilitiesChanged(network, networkCapabilities) - logCapabilitiesIfChanged(network, networkCapabilities, "ConnectivityManager.NetworkCallback", lastNetworkCapabilities) + if (logCapabilitiesIfChanged(network, networkCapabilities, "ConnectivityManager.NetworkCallback", lastNetworkCapabilities)) { + onNetworkLost { !networkCapabilities.hasCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET) } + } } } diff --git a/app/src/test/java/org/thoughtcrime/securesms/messages/NetworkConnectionListenerTest.kt b/app/src/test/java/org/thoughtcrime/securesms/messages/NetworkConnectionListenerTest.kt new file mode 100644 index 0000000000..b44838bd58 --- /dev/null +++ b/app/src/test/java/org/thoughtcrime/securesms/messages/NetworkConnectionListenerTest.kt @@ -0,0 +1,59 @@ +package org.thoughtcrime.securesms.messages + +import android.app.Application +import android.net.ConnectivityManager +import android.net.Network +import android.net.NetworkCapabilities +import androidx.test.core.app.ApplicationProvider +import assertk.assertThat +import assertk.assertions.containsExactly +import io.mockk.mockk +import org.junit.Test +import org.junit.runner.RunWith +import org.robolectric.RobolectricTestRunner +import org.robolectric.annotation.Config +import org.robolectric.util.ReflectionHelpers + +@RunWith(RobolectricTestRunner::class) +@Config(manifest = Config.NONE, application = Application::class, sdk = [31]) +class NetworkConnectionListenerTest { + + @Test + fun `default network capability changes notify listener`() { + val unavailableEvents = mutableListOf() + val listener = NetworkConnectionListener( + context = ApplicationProvider.getApplicationContext(), + onNetworkLost = { isNetworkUnavailable -> unavailableEvents += isNetworkUnavailable() }, + onProxySettingsChanged = {} + ) + val callback = ReflectionHelpers.getField(listener, "networkChangedCallback") + val network = mockk() + + callback.onCapabilitiesChanged(network, capabilities(hasInternet = true, validated = false)) + callback.onCapabilitiesChanged(network, capabilities(hasInternet = true, validated = false)) + callback.onCapabilitiesChanged(network, capabilities(hasInternet = true, validated = true)) + callback.onCapabilitiesChanged(network, capabilities(hasInternet = false, validated = false)) + + assertThat(unavailableEvents).containsExactly(false, false, true) + } + + private fun capabilities(hasInternet: Boolean, validated: Boolean): NetworkCapabilities { + val capabilities = NetworkCapabilities() + + if (hasInternet) { + capabilities.addCapabilityReflectively(NetworkCapabilities.NET_CAPABILITY_INTERNET) + } + + if (validated) { + capabilities.addCapabilityReflectively(NetworkCapabilities.NET_CAPABILITY_VALIDATED) + } + + return capabilities + } + + private fun NetworkCapabilities.addCapabilityReflectively(capability: Int) { + val method = NetworkCapabilities::class.java.getDeclaredMethod("addCapability", Int::class.javaPrimitiveType) + method.isAccessible = true + method.invoke(this, capability) + } +} diff --git a/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt b/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt index 54fdcb6fe5..01099f23e3 100644 --- a/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt +++ b/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt @@ -461,7 +461,7 @@ class LibSignalChatConnection( }, onFailure = { throwable -> Log.w(TAG, "$name [sendKeepAlive] Failure:", throwable) - state.onNext(WebSocketConnectionState.DISCONNECTED) + disconnect() } ) } diff --git a/lib/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt b/lib/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt index d89ce140e1..f786154a30 100644 --- a/lib/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt +++ b/lib/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt @@ -259,12 +259,13 @@ class LibSignalChatConnectionTest { } // Test keepAlive that fails at the transport layer (send() throws), - // which transitions from CONNECTED -> DISCONNECTED. + // which disconnects the underlying chat connection before transitioning to DISCONNECTED. @Test fun keepAliveConnectionFailure() { val connectionFailure = RuntimeException("Sending keep-alive failed") val keepAliveFailureLatch = CountDownLatch(1) + disconnectLatch = CountDownLatch(1) every { chatConnection.send(any()) } answers { delay { @@ -281,15 +282,22 @@ class LibSignalChatConnectionTest { connection.sendKeepAlive() keepAliveFailureLatch.await(100, TimeUnit.MILLISECONDS) + disconnectLatch!!.await(100, TimeUnit.MILLISECONDS) + observer.awaitCount(3) observer.assertNotComplete() observer.assertValues( // We start in the connected state WebSocketConnectionState.CONNECTED, - // Disconnects as a result of keep-alive failure + // Starts an underlying disconnect as a result of keep-alive failure + WebSocketConnectionState.DISCONNECTING, + // Disconnects once libsignal confirms the connection was interrupted WebSocketConnectionState.DISCONNECTED ) observer.assertNoConsecutiveDuplicates() + verify(exactly = 1) { + chatConnection.disconnect() + } verify(exactly = 0) { healthMonitor.onKeepAliveResponse(any(), any()) healthMonitor.onMessageError(any(), any())