Implement unauthenticated chat web socket connection via libsignal-net.

This commit is contained in:
moiseev-signal
2024-04-24 12:19:40 -07:00
committed by Greyson Parrelli
parent 00a91e32fc
commit 95fbd7a31c
13 changed files with 604 additions and 45 deletions

View File

@@ -0,0 +1,217 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.signalservice.internal.websocket
import io.reactivex.rxjava3.core.Observable
import io.reactivex.rxjava3.core.Single
import io.reactivex.rxjava3.schedulers.Schedulers
import io.reactivex.rxjava3.subjects.BehaviorSubject
import io.reactivex.rxjava3.subjects.SingleSubject
import org.signal.core.util.logging.Log
import org.signal.libsignal.internal.CompletableFuture
import org.signal.libsignal.net.ChatService
import org.whispersystems.signalservice.api.websocket.HealthMonitor
import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState
import java.time.Instant
import java.util.Optional
import kotlin.time.Duration.Companion.seconds
import org.signal.libsignal.net.ChatService.Request as LibSignalRequest
import org.signal.libsignal.net.ChatService.Response as LibSignalResponse
/**
* Implements the WebSocketConnection interface via libsignal-net
*
* Notable implementation choices:
* - [chatService] contains both the authenticated and unauthenticated connections,
* which one to use for [sendRequest]/[sendResponse] is based on [isAuthenticated].
* - keep-alive requests always use the [org.signal.libsignal.net.ChatService.unauthenticatedSendAndDebug]
* API, and log the debug info on success.
* - regular sends use [org.signal.libsignal.net.ChatService.unauthenticatedSend] and don't create any overhead.
* - [org.whispersystems.signalservice.api.websocket.WebSocketConnectionState] reporting is implemented
* as close as possible to the original implementation in
* [org.whispersystems.signalservice.internal.websocket.OkHttpWebSocketConnection].
*/
class LibSignalChatConnection(
name: String,
private val chatService: ChatService,
private val healthMonitor: HealthMonitor,
val isAuthenticated: Boolean
) : WebSocketConnection {
companion object {
private val TAG = Log.tag(LibSignalChatConnection::class.java)
private val SEND_TIMEOUT: Long = 10.seconds.inWholeMilliseconds
private val KEEP_ALIVE_REQUEST = LibSignalRequest(
"GET",
"/v1/keepalive",
emptyMap(),
ByteArray(0),
SEND_TIMEOUT.toInt()
)
}
override val name = "[$name:${System.identityHashCode(this)}]"
val state = BehaviorSubject.createDefault(WebSocketConnectionState.DISCONNECTED)
override fun connect(): Observable<WebSocketConnectionState> {
Log.i(TAG, "$name Connecting...")
state.onNext(WebSocketConnectionState.CONNECTING)
val connect = if (isAuthenticated) {
chatService::connectAuthenticated
} else {
chatService::connectUnauthenticated
}
connect()
.whenComplete(
onSuccess = { debugInfo ->
Log.i(TAG, "$name Connected")
Log.d(TAG, "$name $debugInfo")
state.onNext(WebSocketConnectionState.CONNECTED)
},
onFailure = { throwable ->
// TODO: [libsignal-net] Report WebSocketConnectionState.AUTHENTICATION_FAILED for 401 and 403 errors
Log.d(TAG, "$name Connect failed", throwable)
state.onNext(WebSocketConnectionState.FAILED)
}
)
return state
}
override fun isDead(): Boolean = false
override fun disconnect() {
Log.i(TAG, "$name Disconnecting...")
state.onNext(WebSocketConnectionState.DISCONNECTING)
chatService.disconnect()
.whenComplete(
onSuccess = {
Log.i(TAG, "$name Disconnected")
state.onNext(WebSocketConnectionState.DISCONNECTED)
},
onFailure = { throwable ->
Log.d(TAG, "$name Disconnect failed", throwable)
state.onNext(WebSocketConnectionState.DISCONNECTED)
}
)
}
override fun sendRequest(request: WebSocketRequestMessage): Single<WebsocketResponse> {
val single = SingleSubject.create<WebsocketResponse>()
val internalRequest = request.toLibSignalRequest()
val send = if (isAuthenticated) {
throw NotImplementedError("Authenticated socket is not yet supported")
} else {
chatService::unauthenticatedSend
}
send(internalRequest)
.whenComplete(
onSuccess = { response ->
when (response!!.status) {
in 400..599 -> {
healthMonitor.onMessageError(response.status, false)
}
}
// Here success means "we received the response" even if it is reporting an error.
// This is consistent with the behavior of the OkHttpWebSocketConnection.
single.onSuccess(response.toWebsocketResponse(isUnidentified = !isAuthenticated))
},
onFailure = { throwable ->
Log.i(TAG, "$name sendRequest failed", throwable)
single.onError(throwable)
}
)
return single.subscribeOn(Schedulers.io()).observeOn(Schedulers.io())
}
override fun sendKeepAlive() {
Log.i(TAG, "$name Sending keep alive...")
val send = if (isAuthenticated) {
throw NotImplementedError("Authenticated socket is not yet supported")
} else {
chatService::unauthenticatedSendAndDebug
}
send(KEEP_ALIVE_REQUEST)
.whenComplete(
onSuccess = { debugResponse ->
Log.i(TAG, "$name Keep alive - success")
Log.d(TAG, "$name $debugResponse")
when (debugResponse!!.response.status) {
in 200..299 -> {
healthMonitor.onKeepAliveResponse(
Instant.now().toEpochMilli(), // ignored. can be any value
false
)
}
in 400..599 -> {
healthMonitor.onMessageError(debugResponse.response.status, isAuthenticated)
}
else -> {
Log.w(TAG, "$name Unsupported keep alive response status: ${debugResponse.response.status}")
}
}
},
onFailure = { throwable ->
Log.i(TAG, "$name Keep alive - failed")
Log.d(TAG, "$name $throwable")
state.onNext(WebSocketConnectionState.DISCONNECTED)
}
)
}
override fun readRequestIfAvailable(): Optional<WebSocketRequestMessage> {
throw NotImplementedError()
}
override fun readRequest(timeoutMillis: Long): WebSocketRequestMessage {
throw NotImplementedError()
}
override fun sendResponse(response: WebSocketResponseMessage?) {
throw NotImplementedError()
}
private fun WebSocketRequestMessage.toLibSignalRequest(timeout: Long = SEND_TIMEOUT): LibSignalRequest {
return LibSignalRequest(
this.verb?.uppercase() ?: "GET",
this.path ?: "",
this.headers.associate {
val parts = it.split(':', limit = 2)
if (parts.size != 2) {
throw IllegalArgumentException("Headers must contain at least one colon")
}
parts[0] to parts[1]
},
this.body?.toByteArray() ?: byteArrayOf(),
timeout.toInt()
)
}
private fun LibSignalResponse.toWebsocketResponse(isUnidentified: Boolean): WebsocketResponse {
return WebsocketResponse(
this.status,
this.body.decodeToString(),
this.headers,
isUnidentified
)
}
private fun <T> CompletableFuture<T>.whenComplete(
onSuccess: ((T?) -> Unit),
onFailure: ((Throwable) -> Unit)
): CompletableFuture<T> {
return this.whenComplete { value, throwable ->
if (throwable != null) {
onFailure(throwable)
} else {
onSuccess(value)
}
}
}
}

View File

@@ -0,0 +1,23 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.signalservice.internal.websocket
import org.signal.libsignal.net.ChatService
import org.signal.libsignal.net.Network
import org.whispersystems.signalservice.api.util.CredentialsProvider
/**
* Makes Network API more ergonomic to use with Android client types
*/
class LibSignalNetwork(private val inner: Network) {
fun createChatService(
credentialsProvider: CredentialsProvider? = null
): ChatService {
val username = credentialsProvider?.username ?: ""
val password = credentialsProvider?.password ?: ""
return inner.createChatService(username, password)
}
}

View File

@@ -2,7 +2,6 @@ package org.whispersystems.signalservice.internal.websocket;
import org.signal.libsignal.protocol.logging.Log;
import org.signal.libsignal.protocol.util.Pair;
import org.whispersystems.signalservice.api.push.SignalServiceAddress;
import org.whispersystems.signalservice.api.push.TrustStore;
import org.whispersystems.signalservice.api.util.CredentialsProvider;
import org.whispersystems.signalservice.api.util.Tls12SocketFactory;
@@ -25,7 +24,6 @@ import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
@@ -51,10 +49,10 @@ import okhttp3.WebSocket;
import okhttp3.WebSocketListener;
import okio.ByteString;
public class WebSocketConnection extends WebSocketListener {
public class OkHttpWebSocketConnection extends WebSocketListener implements WebSocketConnection {
private static final String TAG = WebSocketConnection.class.getSimpleName();
public static final int KEEPALIVE_FREQUENCY_SECONDS = 30;
private static final String TAG = OkHttpWebSocketConnection.class.getSimpleName();
public static final int KEEPALIVE_FREQUENCY_SECONDS = 30;
private final LinkedList<WebSocketRequestMessage> incomingRequests = new LinkedList<>();
private final Map<Long, OutgoingRequest> outgoingRequests = new HashMap<>();
@@ -76,22 +74,22 @@ public class WebSocketConnection extends WebSocketListener {
private WebSocket client;
public WebSocketConnection(String name,
SignalServiceConfiguration serviceConfiguration,
Optional<CredentialsProvider> credentialsProvider,
String signalAgent,
HealthMonitor healthMonitor,
boolean allowStories) {
public OkHttpWebSocketConnection(String name,
SignalServiceConfiguration serviceConfiguration,
Optional<CredentialsProvider> credentialsProvider,
String signalAgent,
HealthMonitor healthMonitor,
boolean allowStories) {
this(name, serviceConfiguration, credentialsProvider, signalAgent, healthMonitor, "", allowStories);
}
public WebSocketConnection(String name,
SignalServiceConfiguration serviceConfiguration,
Optional<CredentialsProvider> credentialsProvider,
String signalAgent,
HealthMonitor healthMonitor,
String extraPathUri,
boolean allowStories)
public OkHttpWebSocketConnection(String name,
SignalServiceConfiguration serviceConfiguration,
Optional<CredentialsProvider> credentialsProvider,
String signalAgent,
HealthMonitor healthMonitor,
String extraPathUri,
boolean allowStories)
{
this.name = "[" + name + ":" + System.identityHashCode(this) + "]";
this.trustStore = serviceConfiguration.getSignalServiceUrls()[0].getTrustStore();
@@ -108,6 +106,7 @@ public class WebSocketConnection extends WebSocketListener {
this.random = new SecureRandom();
}
@Override
public String getName() {
return name;
}
@@ -123,6 +122,7 @@ public class WebSocketConnection extends WebSocketListener {
}
}
@Override
public synchronized Observable<WebSocketConnectionState> connect() {
log("connect()");
@@ -130,7 +130,7 @@ public class WebSocketConnection extends WebSocketListener {
Pair<SignalServiceUrl, String> connectionInfo = getConnectionInfo();
SignalServiceUrl serviceUrl = connectionInfo.first();
String wsUri = connectionInfo.second();
String filledUri;
String filledUri;
if (credentialsProvider.isPresent()) {
filledUri = String.format(wsUri, credentialsProvider.get().getUsername(), credentialsProvider.get().getPassword());
@@ -177,10 +177,12 @@ public class WebSocketConnection extends WebSocketListener {
return webSocketState;
}
@Override
public synchronized boolean isDead() {
return client == null;
}
@Override
public synchronized void disconnect() {
log("disconnect()");
@@ -193,6 +195,7 @@ public class WebSocketConnection extends WebSocketListener {
notifyAll();
}
@Override
public synchronized Optional<WebSocketRequestMessage> readRequestIfAvailable() {
if (incomingRequests.size() > 0) {
return Optional.of(incomingRequests.removeFirst());
@@ -201,6 +204,7 @@ public class WebSocketConnection extends WebSocketListener {
}
}
@Override
public synchronized WebSocketRequestMessage readRequest(long timeoutMillis)
throws TimeoutException, IOException
{
@@ -223,6 +227,7 @@ public class WebSocketConnection extends WebSocketListener {
}
}
@Override
public synchronized Single<WebsocketResponse> sendRequest(WebSocketRequestMessage request) throws IOException {
if (client == null) {
throw new IOException("No connection!");
@@ -246,6 +251,7 @@ public class WebSocketConnection extends WebSocketListener {
.timeout(10, TimeUnit.SECONDS, Schedulers.io());
}
@Override
public synchronized void sendResponse(WebSocketResponseMessage response) throws IOException {
if (client == null) {
throw new IOException("Connection closed!");
@@ -261,9 +267,10 @@ public class WebSocketConnection extends WebSocketListener {
}
}
@Override
public synchronized void sendKeepAlive() throws IOException {
if (client != null) {
log( "Sending keep alive...");
log("Sending keep alive...");
long id = System.currentTimeMillis();
byte[] message = new WebSocketMessage.Builder()
.type(WebSocketMessage.Type.REQUEST)

View File

@@ -0,0 +1,39 @@
package org.whispersystems.signalservice.internal.websocket
import io.reactivex.rxjava3.core.Observable
import io.reactivex.rxjava3.core.Single
import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState
import java.io.IOException
import java.util.Optional
import java.util.concurrent.TimeoutException
/**
* Common interface for the web socket connection API
*
* At the time of this writing there are two implementations available:
* - OkHttpWebSocketConnection - the original Android client implementation in Java using OkHttp library
* - LibSignalChatConnection - the wrapper around libsignal's [org.signal.libsignal.net.ChatService]
*/
interface WebSocketConnection {
val name: String
fun connect(): Observable<WebSocketConnectionState>
fun isDead(): Boolean
fun disconnect()
@Throws(IOException::class)
fun sendRequest(request: WebSocketRequestMessage): Single<WebsocketResponse>
@Throws(IOException::class)
fun sendKeepAlive()
fun readRequestIfAvailable(): Optional<WebSocketRequestMessage>
@Throws(TimeoutException::class, IOException::class)
fun readRequest(timeoutMillis: Long): WebSocketRequestMessage
@Throws(IOException::class)
fun sendResponse(response: WebSocketResponseMessage?)
}

View File

@@ -15,9 +15,13 @@ public class WebsocketResponse {
private final boolean unidentified;
WebsocketResponse(int status, String body, List<String> headers, boolean unidentified) {
this(status, body, parseHeaders(headers), unidentified);
}
WebsocketResponse(int status, String body, Map<String, String> headerMap, boolean unidentified) {
this.status = status;
this.body = body;
this.headers = parseHeaders(headers);
this.headers = headerMap;
this.unidentified = unidentified;
}
@@ -41,7 +45,7 @@ public class WebsocketResponse {
Map<String, String> headers = new HashMap<>(rawHeaders.size());
for (String raw : rawHeaders) {
if (raw != null && raw.length() > 0) {
if (raw != null && !raw.isEmpty()) {
int colonIndex = raw.indexOf(":");
if (colonIndex > 0 && colonIndex < raw.length() - 1) {