Improve WebSocket health monitoring.

This commit is contained in:
Cody Henthorne
2021-07-27 13:40:33 -04:00
committed by GitHub
parent fc6db45e59
commit 712b0c147a
20 changed files with 559 additions and 351 deletions
@@ -23,9 +23,7 @@ import org.whispersystems.signalservice.api.profiles.SignalServiceProfile;
import org.whispersystems.signalservice.api.push.SignalServiceAddress;
import org.whispersystems.signalservice.api.push.exceptions.MissingConfigurationException;
import org.whispersystems.signalservice.api.util.CredentialsProvider;
import org.whispersystems.signalservice.api.util.SleepTimer;
import org.whispersystems.signalservice.api.util.UuidUtil;
import org.whispersystems.signalservice.api.websocket.ConnectivityListener;
import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration;
import org.whispersystems.signalservice.internal.push.PushServiceSocket;
import org.whispersystems.signalservice.internal.push.SignalServiceEnvelopeEntity;
@@ -50,16 +48,9 @@ import java.util.UUID;
*
* @author Moxie Marlinspike
*/
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class SignalServiceMessageReceiver {
private final PushServiceSocket socket;
private final SignalServiceConfiguration urls;
private final CredentialsProvider credentialsProvider;
private final String signalAgent;
private final ConnectivityListener connectivityListener;
private final SleepTimer sleepTimer;
private final ClientZkProfileOperations clientZkProfileOperations;
private final PushServiceSocket socket;
/**
* Construct a SignalServiceMessageReceiver.
@@ -70,18 +61,10 @@ public class SignalServiceMessageReceiver {
public SignalServiceMessageReceiver(SignalServiceConfiguration urls,
CredentialsProvider credentials,
String signalAgent,
ConnectivityListener listener,
SleepTimer timer,
ClientZkProfileOperations clientZkProfileOperations,
boolean automaticNetworkRetry)
{
this.urls = urls;
this.credentialsProvider = credentials;
this.socket = new PushServiceSocket(urls, credentials, signalAgent, clientZkProfileOperations, automaticNetworkRetry);
this.signalAgent = signalAgent;
this.connectivityListener = listener;
this.sleepTimer = timer;
this.clientZkProfileOperations = clientZkProfileOperations;
this.socket = new PushServiceSocket(urls, credentials, signalAgent, clientZkProfileOperations, automaticNetworkRetry);
}
/**
@@ -4,6 +4,7 @@ import org.whispersystems.libsignal.logging.Log;
import org.whispersystems.libsignal.util.guava.Optional;
import org.whispersystems.signalservice.api.crypto.UnidentifiedAccess;
import org.whispersystems.signalservice.api.messages.SignalServiceEnvelope;
import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState;
import org.whispersystems.signalservice.api.websocket.WebSocketFactory;
import org.whispersystems.signalservice.api.websocket.WebSocketUnavailableException;
import org.whispersystems.signalservice.internal.websocket.WebSocketConnection;
@@ -15,7 +16,12 @@ import org.whispersystems.util.Base64;
import java.io.IOException;
import java.util.concurrent.TimeoutException;
import io.reactivex.rxjava3.core.Observable;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.disposables.CompositeDisposable;
import io.reactivex.rxjava3.disposables.Disposable;
import io.reactivex.rxjava3.schedulers.Schedulers;
import io.reactivex.rxjava3.subjects.BehaviorSubject;
/**
* Provide a general interface to the WebSocket for making requests and reading messages sent by the server.
@@ -29,14 +35,43 @@ public final class SignalWebSocket {
private final WebSocketFactory webSocketFactory;
private WebSocketConnection webSocket;
private WebSocketConnection unidentifiedWebSocket;
private boolean canConnect;
private WebSocketConnection webSocket;
private final BehaviorSubject<WebSocketConnectionState> webSocketState;
private CompositeDisposable webSocketStateDisposable;
private WebSocketConnection unidentifiedWebSocket;
private final BehaviorSubject<WebSocketConnectionState> unidentifiedWebSocketState;
private CompositeDisposable unidentifiedWebSocketStateDisposable;
private boolean canConnect;
public SignalWebSocket(WebSocketFactory webSocketFactory) {
this.webSocketFactory = webSocketFactory;
this.webSocketFactory = webSocketFactory;
this.webSocketState = BehaviorSubject.createDefault(WebSocketConnectionState.DISCONNECTED);
this.unidentifiedWebSocketState = BehaviorSubject.createDefault(WebSocketConnectionState.DISCONNECTED);
this.webSocketStateDisposable = new CompositeDisposable();
this.unidentifiedWebSocketStateDisposable = new CompositeDisposable();
}
/**
* Get an observable stream of the identified WebSocket state. This observable is valid for the lifetime of
* the instance, and will update as WebSocketConnections are remade.
*/
public Observable<WebSocketConnectionState> getWebSocketState() {
return webSocketState;
}
/**
* Get an observable stream of the unidentified WebSocket state. This observable is valid for the lifetime of
* the instance, and will update as WebSocketConnections are remade.
*/
public Observable<WebSocketConnectionState> getUnidentifiedWebSocketState() {
return unidentifiedWebSocketState;
}
/**
* Indicate that WebSocketConnections can now be made and attempt to connect both of them.
*/
public synchronized void connect() {
canConnect = true;
try {
@@ -47,17 +82,54 @@ public final class SignalWebSocket {
}
}
/**
* Indicate that WebSocketConnections can no longer be made and disconnect both of them.
*/
public synchronized void disconnect() {
canConnect = false;
disconnectIdentified();
disconnectUnidentified();
}
/**
* Indicate that the current WebSocket instances need to be destroyed and new ones should be created the
* next time a connection is required. Intended to be used by the health monitor to cycle a WebSocket.
*/
public synchronized void forceNewWebSockets() {
Log.i(TAG, "Forcing new WebSockets " +
" identified: " + (webSocket != null ? webSocket.getName() : "[null]") +
" unidentified: " + (unidentifiedWebSocket != null ? unidentifiedWebSocket.getName() : "[null]") +
" canConnect: " + canConnect);
disconnectIdentified();
disconnectUnidentified();
}
private void disconnectIdentified() {
if (webSocket != null) {
webSocketStateDisposable.dispose();
webSocket.disconnect();
webSocket = null;
}
//noinspection ConstantConditions
if (!webSocketState.getValue().isFailure()) {
webSocketState.onNext(WebSocketConnectionState.DISCONNECTED);
}
}
}
private void disconnectUnidentified() {
if (unidentifiedWebSocket != null) {
unidentifiedWebSocketStateDisposable.dispose();
unidentifiedWebSocket.disconnect();
unidentifiedWebSocket = null;
//noinspection ConstantConditions
if (!unidentifiedWebSocketState.getValue().isFailure()) {
unidentifiedWebSocketState.onNext(WebSocketConnectionState.DISCONNECTED);
}
}
}
@@ -67,8 +139,16 @@ public final class SignalWebSocket {
}
if (webSocket == null || webSocket.isDead()) {
webSocket = webSocketFactory.createWebSocket();
webSocket.connect();
webSocketStateDisposable.dispose();
webSocket = webSocketFactory.createWebSocket();
webSocketStateDisposable = new CompositeDisposable();
Disposable state = webSocket.connect()
.subscribeOn(Schedulers.computation())
.observeOn(Schedulers.computation())
.subscribe(webSocketState::onNext);
webSocketStateDisposable.add(state);
}
return webSocket;
}
@@ -79,12 +159,34 @@ public final class SignalWebSocket {
}
if (unidentifiedWebSocket == null || unidentifiedWebSocket.isDead()) {
unidentifiedWebSocket = webSocketFactory.createUnidentifiedWebSocket();
unidentifiedWebSocket.connect();
unidentifiedWebSocketStateDisposable.dispose();
unidentifiedWebSocket = webSocketFactory.createUnidentifiedWebSocket();
unidentifiedWebSocketStateDisposable = new CompositeDisposable();
Disposable state = unidentifiedWebSocket.connect()
.subscribeOn(Schedulers.computation())
.observeOn(Schedulers.computation())
.subscribe(unidentifiedWebSocketState::onNext);
unidentifiedWebSocketStateDisposable.add(state);
}
return unidentifiedWebSocket;
}
/**
* Send keep-alive messages over both WebSocketConnections.
*/
public synchronized void sendKeepAlive() throws IOException {
if (canConnect) {
try {
getWebSocket().sendKeepAlive();
getUnidentifiedWebSocket().sendKeepAlive();
} catch (WebSocketUnavailableException e) {
throw new AssertionError(e);
}
}
}
public Single<WebsocketResponse> request(WebSocketRequestMessage requestMessage) {
try {
return getWebSocket().sendRequest(requestMessage);
@@ -98,20 +200,18 @@ public final class SignalWebSocket {
WebSocketRequestMessage message = WebSocketRequestMessage.newBuilder(requestMessage)
.addHeaders("Unidentified-Access-Key:" + Base64.encodeBytes(unidentifiedAccess.get().getUnidentifiedAccessKey()))
.build();
Single<WebsocketResponse> response;
try {
response = getUnidentifiedWebSocket().sendRequest(message);
return getUnidentifiedWebSocket().sendRequest(message)
.flatMap(r -> {
if (r.getStatus() == 401) {
return request(requestMessage);
}
return Single.just(r);
})
.onErrorResumeNext(t -> request(requestMessage));
} catch (IOException e) {
return Single.error(e);
}
return response.flatMap(r -> {
if (r.getStatus() == 401) {
return request(requestMessage);
}
return Single.just(r);
})
.onErrorResumeNext(t -> request(requestMessage));
} else {
return request(requestMessage);
}
@@ -1,12 +0,0 @@
package org.whispersystems.signalservice.api.websocket;
import okhttp3.Response;
public interface ConnectivityListener {
void onConnected();
void onConnecting();
void onDisconnected();
void onAuthenticationFailure();
boolean onGenericFailure(Response response, Throwable throwable);
}
@@ -0,0 +1,10 @@
package org.whispersystems.signalservice.api.websocket;
/**
* Callbacks to provide WebSocket health information to a monitor.
*/
public interface HealthMonitor {
void onKeepAliveResponse(long sentTimestamp, boolean isIdentifiedWebSocket);
void onMessageError(int status, boolean isIdentifiedWebSocket);
}
@@ -0,0 +1,18 @@
package org.whispersystems.signalservice.api.websocket;
/**
* Represent the state of a single WebSocketConnection.
*/
public enum WebSocketConnectionState {
DISCONNECTED,
CONNECTING,
CONNECTED,
RECONNECTING,
DISCONNECTING,
AUTHENTICATION_FAILED,
FAILED;
public boolean isFailure() {
return this == AUTHENTICATION_FAILED || this == FAILED;
}
}
@@ -7,10 +7,10 @@ import org.whispersystems.libsignal.util.Pair;
import org.whispersystems.libsignal.util.guava.Optional;
import org.whispersystems.signalservice.api.push.TrustStore;
import org.whispersystems.signalservice.api.util.CredentialsProvider;
import org.whispersystems.signalservice.api.util.SleepTimer;
import org.whispersystems.signalservice.api.util.Tls12SocketFactory;
import org.whispersystems.signalservice.api.util.TlsProxySocketFactory;
import org.whispersystems.signalservice.api.websocket.ConnectivityListener;
import org.whispersystems.signalservice.api.websocket.HealthMonitor;
import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState;
import org.whispersystems.signalservice.internal.configuration.SignalProxy;
import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration;
import org.whispersystems.signalservice.internal.util.BlacklistingTrustManager;
@@ -20,21 +20,24 @@ import java.io.IOException;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import io.reactivex.rxjava3.core.Observable;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.schedulers.Schedulers;
import io.reactivex.rxjava3.subjects.BehaviorSubject;
import io.reactivex.rxjava3.subjects.SingleSubject;
import okhttp3.ConnectionSpec;
import okhttp3.Dns;
@@ -53,45 +56,40 @@ import static org.whispersystems.signalservice.internal.websocket.WebSocketProto
public class WebSocketConnection extends WebSocketListener {
private static final String TAG = WebSocketConnection.class.getSimpleName();
private static final int KEEPALIVE_TIMEOUT_SECONDS = 55;
public static final int KEEPALIVE_TIMEOUT_SECONDS = 55;
private final LinkedList<WebSocketRequestMessage> incomingRequests = new LinkedList<>();
private final Map<Long, OutgoingRequest> outgoingRequests = new HashMap<>();
private final Set<Long> keepAlives = new HashSet<>();
private final String name;
private final String wsUri;
private final TrustStore trustStore;
private final Optional<CredentialsProvider> credentialsProvider;
private final String signalAgent;
private ConnectivityListener listener;
private final SleepTimer sleepTimer;
private final List<Interceptor> interceptors;
private final Optional<Dns> dns;
private final Optional<SignalProxy> signalProxy;
private final String name;
private final String wsUri;
private final TrustStore trustStore;
private final Optional<CredentialsProvider> credentialsProvider;
private final String signalAgent;
private final HealthMonitor healthMonitor;
private final List<Interceptor> interceptors;
private final Optional<Dns> dns;
private final Optional<SignalProxy> signalProxy;
private final BehaviorSubject<WebSocketConnectionState> webSocketState;
private WebSocket client;
private KeepAliveSender keepAliveSender;
private int attempts;
private boolean connected;
private WebSocket client;
public WebSocketConnection(String name,
SignalServiceConfiguration serviceConfiguration,
Optional<CredentialsProvider> credentialsProvider,
String signalAgent,
ConnectivityListener listener,
SleepTimer timer)
HealthMonitor healthMonitor)
{
this.name = "[" + name + ":" + System.identityHashCode(this) + "]";
this.trustStore = serviceConfiguration.getSignalServiceUrls()[0].getTrustStore();
this.credentialsProvider = credentialsProvider;
this.signalAgent = signalAgent;
this.listener = listener;
this.sleepTimer = timer;
this.interceptors = serviceConfiguration.getNetworkInterceptors();
this.dns = serviceConfiguration.getDns();
this.signalProxy = serviceConfiguration.getSignalProxy();
this.attempts = 0;
this.connected = false;
this.healthMonitor = healthMonitor;
this.webSocketState = BehaviorSubject.createDefault(WebSocketConnectionState.DISCONNECTED);
String uri = serviceConfiguration.getSignalServiceUrls()[0].getUrl().replace("https://", "wss://").replace("http://", "ws://");
@@ -102,7 +100,11 @@ public class WebSocketConnection extends WebSocketListener {
}
}
public synchronized void connect() {
public String getName() {
return name;
}
public synchronized Observable<WebSocketConnectionState> connect() {
log("connect()");
if (client == null) {
@@ -117,12 +119,12 @@ public class WebSocketConnection extends WebSocketListener {
Pair<SSLSocketFactory, X509TrustManager> socketFactory = createTlsSocketFactory(trustStore);
OkHttpClient.Builder clientBuilder = new OkHttpClient.Builder()
.sslSocketFactory(new Tls12SocketFactory(socketFactory.first()), socketFactory.second())
.connectionSpecs(Util.immutableList(ConnectionSpec.RESTRICTED_TLS))
.readTimeout(KEEPALIVE_TIMEOUT_SECONDS + 10, TimeUnit.SECONDS)
.dns(dns.or(Dns.SYSTEM))
.connectTimeout(KEEPALIVE_TIMEOUT_SECONDS + 10, TimeUnit.SECONDS);
OkHttpClient.Builder clientBuilder = new OkHttpClient.Builder().sslSocketFactory(new Tls12SocketFactory(socketFactory.first()),
socketFactory.second())
.connectionSpecs(Util.immutableList(ConnectionSpec.RESTRICTED_TLS))
.readTimeout(KEEPALIVE_TIMEOUT_SECONDS + 10, TimeUnit.SECONDS)
.dns(dns.or(Dns.SYSTEM))
.connectTimeout(KEEPALIVE_TIMEOUT_SECONDS + 10, TimeUnit.SECONDS);
for (Interceptor interceptor : interceptors) {
clientBuilder.addInterceptor(interceptor);
@@ -140,13 +142,11 @@ public class WebSocketConnection extends WebSocketListener {
requestBuilder.addHeader("X-Signal-Agent", signalAgent);
}
if (listener != null) {
listener.onConnecting();
}
webSocketState.onNext(WebSocketConnectionState.CONNECTING);
this.connected = false;
this.client = okHttpClient.newWebSocket(requestBuilder.build(), this);
this.client = okHttpClient.newWebSocket(requestBuilder.build(), this);
}
return webSocketState;
}
public synchronized boolean isDead() {
@@ -158,18 +158,8 @@ public class WebSocketConnection extends WebSocketListener {
if (client != null) {
client.close(1000, "OK");
client = null;
connected = false;
}
if (keepAliveSender != null) {
keepAliveSender.shutdown();
keepAliveSender = null;
}
if (listener != null) {
listener.onDisconnected();
listener = null;
client = null;
webSocketState.onNext(WebSocketConnectionState.DISCONNECTING);
}
notifyAll();
@@ -198,7 +188,7 @@ public class WebSocketConnection extends WebSocketListener {
}
public synchronized Single<WebsocketResponse> sendRequest(WebSocketRequestMessage request) throws IOException {
if (client == null || !connected) {
if (client == null) {
throw new IOException("No connection!");
}
@@ -235,17 +225,20 @@ public class WebSocketConnection extends WebSocketListener {
}
}
private synchronized void sendKeepAlive() throws IOException {
if (keepAliveSender != null && client != null) {
public synchronized void sendKeepAlive() throws IOException {
if (client != null) {
log( "Sending keep alive...");
long id = System.currentTimeMillis();
byte[] message = WebSocketMessage.newBuilder()
.setType(WebSocketMessage.Type.REQUEST)
.setRequest(WebSocketRequestMessage.newBuilder()
.setId(System.currentTimeMillis())
.setId(id)
.setPath("/v1/keepalive")
.setVerb("GET")
.build()).build()
.build())
.build()
.toByteArray();
keepAlives.add(id);
if (!client.send(ByteString.of(message))) {
throw new IOException("Write failed!");
}
@@ -254,16 +247,9 @@ public class WebSocketConnection extends WebSocketListener {
@Override
public synchronized void onOpen(WebSocket webSocket, Response response) {
if (client != null && keepAliveSender == null) {
if (client != null) {
log("onOpen() connected");
attempts = 0;
connected = true;
keepAliveSender = new KeepAliveSender();
keepAliveSender.start();
if (listener != null) {
listener.onConnected();
}
webSocketState.onNext(WebSocketConnectionState.CONNECTED);
}
}
@@ -280,6 +266,11 @@ public class WebSocketConnection extends WebSocketListener {
listener.onSuccess(new WebsocketResponse(message.getResponse().getStatus(),
new String(message.getResponse().getBody().toByteArray()),
message.getResponse().getHeadersList()));
if (message.getResponse().getStatus() >= 400) {
healthMonitor.onMessageError(message.getResponse().getStatus(), credentialsProvider.isPresent());
}
} else if (keepAlives.remove(message.getResponse().getId())) {
healthMonitor.onKeepAliveResponse(message.getResponse().getId(), credentialsProvider.isPresent());
}
}
@@ -292,34 +283,9 @@ public class WebSocketConnection extends WebSocketListener {
@Override
public synchronized void onClosed(WebSocket webSocket, int code, String reason) {
log("onClose()");
this.connected = false;
webSocketState.onNext(WebSocketConnectionState.DISCONNECTED);
Iterator<Map.Entry<Long, OutgoingRequest>> iterator = outgoingRequests.entrySet().iterator();
while (iterator.hasNext()) {
Map.Entry<Long, OutgoingRequest> entry = iterator.next();
entry.getValue().onError(new IOException("Closed: " + code + ", " + reason));
iterator.remove();
}
if (keepAliveSender != null) {
keepAliveSender.shutdown();
keepAliveSender = null;
}
if (listener != null) {
listener.onDisconnected();
}
Util.wait(this, Math.min(++attempts * 200, TimeUnit.SECONDS.toMillis(15)));
if (client != null) {
log("Client not null when closed, attempting to reconnect");
client.close(1000, "OK");
client = null;
connected = false;
connect();
}
cleanupAfterShutdown();
notifyAll();
}
@@ -329,19 +295,29 @@ public class WebSocketConnection extends WebSocketListener {
warn("onFailure()", t);
if (response != null && (response.code() == 401 || response.code() == 403)) {
if (listener != null) {
listener.onAuthenticationFailure();
}
} else if (listener != null) {
boolean shouldRetryConnection = listener.onGenericFailure(response, t);
if (!shouldRetryConnection) {
warn("Experienced a failure, and the listener indicated we should not retry the connection. Disconnecting.");
disconnect();
}
webSocketState.onNext(WebSocketConnectionState.AUTHENTICATION_FAILED);
} else {
webSocketState.onNext(WebSocketConnectionState.FAILED);
}
cleanupAfterShutdown();
notifyAll();
}
private void cleanupAfterShutdown() {
Iterator<Map.Entry<Long, OutgoingRequest>> iterator = outgoingRequests.entrySet().iterator();
while (iterator.hasNext()) {
Map.Entry<Long, OutgoingRequest> entry = iterator.next();
entry.getValue().onError(new IOException("Closed unexpectedly"));
iterator.remove();
}
if (client != null) {
onClosed(webSocket, 1000, "OK");
log("Client not null when closed");
client.close(1000, "OK");
client = null;
}
}
@@ -353,6 +329,7 @@ public class WebSocketConnection extends WebSocketListener {
@Override
public synchronized void onClosing(WebSocket webSocket, int code, String reason) {
log("onClosing()");
webSocketState.onNext(WebSocketConnectionState.DISCONNECTING);
webSocket.close(1000, "OK");
}
@@ -390,30 +367,6 @@ public class WebSocketConnection extends WebSocketListener {
Log.w(TAG, name + " " + message, e);
}
private class KeepAliveSender extends Thread {
private final AtomicBoolean stop = new AtomicBoolean(false);
public void run() {
while (!stop.get()) {
try {
sleepTimer.sleep(TimeUnit.SECONDS.toMillis(KEEPALIVE_TIMEOUT_SECONDS));
if (!stop.get()) {
log("Sending keep alive...");
sendKeepAlive();
}
} catch (Throwable e) {
warn(e);
}
}
}
public void shutdown() {
stop.set(true);
}
}
private static class OutgoingRequest {
private final SingleSubject<WebsocketResponse> responseSingle;