From 7f2b6178d5a8a1d5eb17e8cd406ae806ea174575 Mon Sep 17 00:00:00 2001 From: Moxie Marlinspike Date: Mon, 1 Feb 2021 18:52:01 -0800 Subject: [PATCH] Add support for configuring a signal proxy. --- .../push/SignalServiceNetworkAccess.java | 5 + .../api/SignalServiceMessageReceiver.java | 6 +- .../api/util/TlsProxySocketFactory.java | 306 ++++++++++++++++++ .../internal/configuration/SignalProxy.java | 19 ++ .../SignalServiceConfiguration.java | 7 + .../internal/push/PushServiceSocket.java | 35 +- .../websocket/WebSocketConnection.java | 11 +- 7 files changed, 373 insertions(+), 16 deletions(-) create mode 100644 libsignal/service/src/main/java/org/whispersystems/signalservice/api/util/TlsProxySocketFactory.java create mode 100644 libsignal/service/src/main/java/org/whispersystems/signalservice/internal/configuration/SignalProxy.java diff --git a/app/src/main/java/org/thoughtcrime/securesms/push/SignalServiceNetworkAccess.java b/app/src/main/java/org/thoughtcrime/securesms/push/SignalServiceNetworkAccess.java index 09b1e781ba..4826e2cea1 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/push/SignalServiceNetworkAccess.java +++ b/app/src/main/java/org/thoughtcrime/securesms/push/SignalServiceNetworkAccess.java @@ -185,6 +185,7 @@ public class SignalServiceNetworkAccess { new SignalStorageUrl[] {egyptGoogleStorage, baseGoogleStorage, baseAndroidStorage, mapsOneAndroidStorage, mapsTwoAndroidStorage, mailAndroidStorage}, interceptors, dns, + Optional.absent(), zkGroupServerPublicParams)); put(COUNTRY_CODE_UAE, new SignalServiceConfiguration(new SignalServiceUrl[] {uaeGoogleService, baseAndroidService, baseGoogleService, mapsOneAndroidService, mapsTwoAndroidService, mailAndroidService}, @@ -195,6 +196,7 @@ public class SignalServiceNetworkAccess { new SignalStorageUrl[] {uaeGoogleStorage, baseGoogleStorage, baseAndroidStorage, mapsOneAndroidStorage, mapsTwoAndroidStorage, mailAndroidStorage}, interceptors, dns, + Optional.absent(), zkGroupServerPublicParams)); put(COUNTRY_CODE_OMAN, new SignalServiceConfiguration(new SignalServiceUrl[] {omanGoogleService, baseAndroidService, baseGoogleService, mapsOneAndroidService, mapsTwoAndroidService, mailAndroidService}, @@ -205,6 +207,7 @@ public class SignalServiceNetworkAccess { new SignalStorageUrl[] {omanGoogleStorage, baseGoogleStorage, baseAndroidStorage, mapsOneAndroidStorage, mapsTwoAndroidStorage, mailAndroidStorage}, interceptors, dns, + Optional.absent(), zkGroupServerPublicParams)); @@ -216,6 +219,7 @@ public class SignalServiceNetworkAccess { new SignalStorageUrl[] {qatarGoogleStorage, baseGoogleStorage, baseAndroidStorage, mapsOneAndroidStorage, mapsTwoAndroidStorage, mailAndroidStorage}, interceptors, dns, + Optional.absent(), zkGroupServerPublicParams)); }}; @@ -227,6 +231,7 @@ public class SignalServiceNetworkAccess { new SignalStorageUrl[] {new SignalStorageUrl(BuildConfig.STORAGE_URL, new SignalServiceTrustStore(context))}, interceptors, dns, + Optional.absent(), zkGroupServerPublicParams); this.censoredCountries = this.censorshipConfiguration.keySet().toArray(new String[0]); diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageReceiver.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageReceiver.java index 2ab3478cda..fe2ef19119 100644 --- a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageReceiver.java +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageReceiver.java @@ -243,7 +243,8 @@ public class SignalServiceMessageReceiver { Optional.of(credentialsProvider), signalAgent, connectivityListener, sleepTimer, urls.getNetworkInterceptors(), - urls.getDns()); + urls.getDns(), + urls.getSignalProxy()); return new SignalServiceMessagePipe(webSocket, Optional.of(credentialsProvider), clientZkProfileOperations); } @@ -254,7 +255,8 @@ public class SignalServiceMessageReceiver { Optional.absent(), signalAgent, connectivityListener, sleepTimer, urls.getNetworkInterceptors(), - urls.getDns()); + urls.getDns(), + urls.getSignalProxy()); return new SignalServiceMessagePipe(webSocket, Optional.of(credentialsProvider), clientZkProfileOperations); } diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/util/TlsProxySocketFactory.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/util/TlsProxySocketFactory.java new file mode 100644 index 0000000000..a5eb3f792a --- /dev/null +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/util/TlsProxySocketFactory.java @@ -0,0 +1,306 @@ +package org.whispersystems.signalservice.api.util; + +import org.whispersystems.libsignal.util.guava.Optional; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.InetAddress; +import java.net.Socket; +import java.net.SocketAddress; +import java.net.SocketException; +import java.net.SocketOption; +import java.net.UnknownHostException; +import java.nio.channels.SocketChannel; +import java.security.KeyManagementException; +import java.security.NoSuchAlgorithmException; +import java.util.List; +import java.util.Set; + +import javax.net.SocketFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; + +import okhttp3.Dns; + +public class TlsProxySocketFactory extends SocketFactory { + + private final SSLSocketFactory system; + + private final String proxyHost; + private final int proxyPort; + private final Optional dns; + + public TlsProxySocketFactory(String proxyHost, int proxyPort, Optional dns) { + try { + SSLContext context = SSLContext.getInstance("TLS"); + context.init(null, null, null); + + this.system = context.getSocketFactory(); + this.proxyHost = proxyHost; + this.proxyPort = proxyPort; + this.dns = dns; + } catch (NoSuchAlgorithmException | KeyManagementException e) { + throw new AssertionError(e); + } + } + + @Override + public Socket createSocket(String host, int port) throws IOException, UnknownHostException { + if (dns.isPresent()) { + List resolved = dns.get().lookup(host); + + if (resolved.size() > 0) { + return createSocket(resolved.get(0), port); + } + } + + return new ProxySocket(system.createSocket(proxyHost, proxyPort)); + } + + @Override + public Socket createSocket(String host, int port, InetAddress localHost, int localPort) throws IOException, UnknownHostException { + if (dns.isPresent()) { + List resolved = dns.get().lookup(host); + + if (resolved.size() > 0) { + return createSocket(resolved.get(0), port, localHost, localPort); + } + } + + return new ProxySocket(system.createSocket(proxyHost, proxyPort, localHost, localPort)); + } + + @Override + public Socket createSocket(InetAddress host, int port) throws IOException { + return new ProxySocket(system.createSocket(proxyHost, proxyPort)); + } + + @Override + public Socket createSocket(InetAddress address, int port, InetAddress localAddress, int localPort) throws IOException { + return new ProxySocket(system.createSocket(proxyHost, proxyPort, localAddress, localPort)); + } + + @Override + public Socket createSocket() throws IOException { + SSLSocket socket = (SSLSocket)system.createSocket(proxyHost, proxyPort); + socket.startHandshake(); + + return new ProxySocket(socket); + } + + private static class ProxySocket extends Socket { + + private final Socket delegate; + + private ProxySocket(Socket delegate) { + this.delegate = delegate; + } + + @Override + public void bind(SocketAddress bindpoint) throws IOException { + delegate.bind(bindpoint); + } + + @Override + public InetAddress getInetAddress() { + return delegate.getInetAddress(); + } + + @Override + public InetAddress getLocalAddress() { + return delegate.getLocalAddress(); + } + + @Override + public int getPort() { + return delegate.getPort(); + } + + @Override + public int getLocalPort() { + return delegate.getLocalPort(); + } + + @Override + public SocketAddress getRemoteSocketAddress() { + return delegate.getRemoteSocketAddress(); + } + + @Override + public SocketAddress getLocalSocketAddress() { + return delegate.getLocalSocketAddress(); + } + + @Override + public SocketChannel getChannel() { + return delegate.getChannel(); + } + + @Override + public InputStream getInputStream() throws IOException { + return delegate.getInputStream(); + } + + @Override + public OutputStream getOutputStream() throws IOException { + return delegate.getOutputStream(); + } + + @Override + public void setTcpNoDelay(boolean on) throws SocketException { + delegate.setTcpNoDelay(on); + } + + @Override + public boolean getTcpNoDelay() throws SocketException { + return delegate.getTcpNoDelay(); + } + + @Override + public void setSoLinger(boolean on, int linger) throws SocketException { + delegate.setSoLinger(on, linger); + } + + @Override + public int getSoLinger() throws SocketException { + return delegate.getSoLinger(); + } + + @Override + public void sendUrgentData(int data) throws IOException { + delegate.sendUrgentData(data); + } + + @Override + public void setOOBInline(boolean on) throws SocketException { + delegate.setOOBInline(on); + } + + @Override + public boolean getOOBInline() throws SocketException { + return delegate.getOOBInline(); + } + + @Override + public void setSoTimeout(int timeout) throws SocketException { + delegate.setSoTimeout(timeout); + } + + @Override + public int getSoTimeout() throws SocketException { + return delegate.getSoTimeout(); + } + + @Override + public void setSendBufferSize(int size) throws SocketException { + delegate.setSendBufferSize(size); + } + + @Override + public int getSendBufferSize() throws SocketException { + return delegate.getSendBufferSize(); + } + + @Override + public void setReceiveBufferSize(int size) throws SocketException { + delegate.setReceiveBufferSize(size); + } + + @Override + public int getReceiveBufferSize() throws SocketException { + return delegate.getReceiveBufferSize(); + } + + @Override + public void setKeepAlive(boolean on) throws SocketException { + delegate.setKeepAlive(on); + } + + @Override + public boolean getKeepAlive() throws SocketException { + return delegate.getKeepAlive(); + } + + @Override + public void setTrafficClass(int tc) throws SocketException { + delegate.setTrafficClass(tc); + } + + @Override + public int getTrafficClass() throws SocketException { + return delegate.getTrafficClass(); + } + + @Override + public void setReuseAddress(boolean on) throws SocketException { + delegate.setReuseAddress(on); + } + + @Override + public boolean getReuseAddress() throws SocketException { + return delegate.getReuseAddress(); + } + + @Override + public void close() throws IOException { + delegate.close(); + } + + @Override + public void shutdownInput() throws IOException { + delegate.shutdownInput(); + } + + @Override + public void shutdownOutput() throws IOException { + delegate.shutdownOutput(); + } + + @Override + public String toString() { + return delegate.toString(); + } + + @Override + public boolean isConnected() { + return delegate.isConnected(); + } + + @Override + public boolean isBound() { + return delegate.isBound(); + } + + @Override + public boolean isClosed() { + return delegate.isClosed(); + } + + @Override + public boolean isInputShutdown() { + return delegate.isInputShutdown(); + } + + @Override + public boolean isOutputShutdown() { + return delegate.isOutputShutdown(); + } + + @Override + public void setPerformancePreferences(int connectionTime, int latency, int bandwidth) { + delegate.setPerformancePreferences(connectionTime, latency, bandwidth); + } + + @Override + public void connect(SocketAddress endpoint) throws IOException { + // Already connected + } + + @Override + public void connect(SocketAddress endpoint, int timeout) throws IOException { + // Already connected + } + } +} diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/internal/configuration/SignalProxy.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/internal/configuration/SignalProxy.java new file mode 100644 index 0000000000..d855689b0e --- /dev/null +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/internal/configuration/SignalProxy.java @@ -0,0 +1,19 @@ +package org.whispersystems.signalservice.internal.configuration; + +public class SignalProxy { + private final String host; + private final int port; + + public SignalProxy(String host, int port) { + this.host = host; + this.port = port; + } + + public String getHost() { + return host; + } + + public int getPort() { + return port; + } +} diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/internal/configuration/SignalServiceConfiguration.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/internal/configuration/SignalServiceConfiguration.java index 44b065721c..85f91fbfc5 100644 --- a/libsignal/service/src/main/java/org/whispersystems/signalservice/internal/configuration/SignalServiceConfiguration.java +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/internal/configuration/SignalServiceConfiguration.java @@ -17,6 +17,7 @@ public final class SignalServiceConfiguration { private final SignalStorageUrl[] signalStorageUrls; private final List networkInterceptors; private final Optional dns; + private final Optional proxy; private final byte[] zkGroupServerPublicParams; public SignalServiceConfiguration(SignalServiceUrl[] signalServiceUrls, @@ -26,6 +27,7 @@ public final class SignalServiceConfiguration { SignalStorageUrl[] signalStorageUrls, List networkInterceptors, Optional dns, + Optional proxy, byte[] zkGroupServerPublicParams) { this.signalServiceUrls = signalServiceUrls; @@ -35,6 +37,7 @@ public final class SignalServiceConfiguration { this.signalStorageUrls = signalStorageUrls; this.networkInterceptors = networkInterceptors; this.dns = dns; + this.proxy = proxy; this.zkGroupServerPublicParams = zkGroupServerPublicParams; } @@ -69,4 +72,8 @@ public final class SignalServiceConfiguration { public byte[] getZkGroupServerPublicParams() { return zkGroupServerPublicParams; } + + public Optional getSignalProxy() { + return proxy; + } } diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/internal/push/PushServiceSocket.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/internal/push/PushServiceSocket.java index 22ca7fec85..79d5a6dcab 100644 --- a/libsignal/service/src/main/java/org/whispersystems/signalservice/internal/push/PushServiceSocket.java +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/internal/push/PushServiceSocket.java @@ -71,8 +71,10 @@ import org.whispersystems.signalservice.api.push.exceptions.UsernameTakenExcepti import org.whispersystems.signalservice.api.storage.StorageAuthResponse; import org.whispersystems.signalservice.api.util.CredentialsProvider; import org.whispersystems.signalservice.api.util.Tls12SocketFactory; +import org.whispersystems.signalservice.api.util.TlsProxySocketFactory; import org.whispersystems.signalservice.api.util.UuidUtil; import org.whispersystems.signalservice.internal.configuration.SignalCdnUrl; +import org.whispersystems.signalservice.internal.configuration.SignalProxy; import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration; import org.whispersystems.signalservice.internal.configuration.SignalUrl; import org.whispersystems.signalservice.internal.contacts.entities.DiscoveryRequest; @@ -129,6 +131,7 @@ import java.util.UUID; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; +import javax.net.SocketFactory; import javax.net.ssl.SSLContext; import javax.net.ssl.TrustManager; import javax.net.ssl.X509TrustManager; @@ -245,11 +248,11 @@ public class PushServiceSocket { this.credentialsProvider = credentialsProvider; this.signalAgent = signalAgent; this.automaticNetworkRetry = automaticNetworkRetry; - this.serviceClients = createServiceConnectionHolders(configuration.getSignalServiceUrls(), configuration.getNetworkInterceptors(), configuration.getDns()); - this.cdnClientsMap = createCdnClientsMap(configuration.getSignalCdnUrlMap(), configuration.getNetworkInterceptors(), configuration.getDns()); - this.contactDiscoveryClients = createConnectionHolders(configuration.getSignalContactDiscoveryUrls(), configuration.getNetworkInterceptors(), configuration.getDns()); - this.keyBackupServiceClients = createConnectionHolders(configuration.getSignalKeyBackupServiceUrls(), configuration.getNetworkInterceptors(), configuration.getDns()); - this.storageClients = createConnectionHolders(configuration.getSignalStorageUrls(), configuration.getNetworkInterceptors(), configuration.getDns()); + this.serviceClients = createServiceConnectionHolders(configuration.getSignalServiceUrls(), configuration.getNetworkInterceptors(), configuration.getDns(), configuration.getSignalProxy()); + this.cdnClientsMap = createCdnClientsMap(configuration.getSignalCdnUrlMap(), configuration.getNetworkInterceptors(), configuration.getDns(), configuration.getSignalProxy()); + this.contactDiscoveryClients = createConnectionHolders(configuration.getSignalContactDiscoveryUrls(), configuration.getNetworkInterceptors(), configuration.getDns(), configuration.getSignalProxy()); + this.keyBackupServiceClients = createConnectionHolders(configuration.getSignalKeyBackupServiceUrls(), configuration.getNetworkInterceptors(), configuration.getDns(), configuration.getSignalProxy()); + this.storageClients = createConnectionHolders(configuration.getSignalStorageUrls(), configuration.getNetworkInterceptors(), configuration.getDns(), configuration.getSignalProxy()); this.random = new SecureRandom(); this.clientZkProfileOperations = clientZkProfileOperations; } @@ -1741,13 +1744,14 @@ public class PushServiceSocket { private ServiceConnectionHolder[] createServiceConnectionHolders(SignalUrl[] urls, List interceptors, - Optional dns) + Optional dns, + Optional proxy) { List serviceConnectionHolders = new LinkedList<>(); for (SignalUrl url : urls) { - serviceConnectionHolders.add(new ServiceConnectionHolder(createConnectionClient(url, interceptors, dns), - createConnectionClient(url, interceptors, dns), + serviceConnectionHolders.add(new ServiceConnectionHolder(createConnectionClient(url, interceptors, dns, proxy), + createConnectionClient(url, interceptors, dns, proxy), url.getUrl(), url.getHostHeader())); } @@ -1756,12 +1760,13 @@ public class PushServiceSocket { private static Map createCdnClientsMap(final Map signalCdnUrlMap, final List interceptors, - final Optional dns) { + final Optional dns, + final Optional proxy) { validateConfiguration(signalCdnUrlMap); final Map result = new HashMap<>(); for (Map.Entry entry : signalCdnUrlMap.entrySet()) { result.put(entry.getKey(), - createConnectionHolders(entry.getValue(), interceptors, dns)); + createConnectionHolders(entry.getValue(), interceptors, dns, proxy)); } return Collections.unmodifiableMap(result); } @@ -1772,17 +1777,17 @@ public class PushServiceSocket { } } - private static ConnectionHolder[] createConnectionHolders(SignalUrl[] urls, List interceptors, Optional dns) { + private static ConnectionHolder[] createConnectionHolders(SignalUrl[] urls, List interceptors, Optional dns, Optional proxy) { List connectionHolders = new LinkedList<>(); for (SignalUrl url : urls) { - connectionHolders.add(new ConnectionHolder(createConnectionClient(url, interceptors, dns), url.getUrl(), url.getHostHeader())); + connectionHolders.add(new ConnectionHolder(createConnectionClient(url, interceptors, dns, proxy), url.getUrl(), url.getHostHeader())); } return connectionHolders.toArray(new ConnectionHolder[0]); } - private static OkHttpClient createConnectionClient(SignalUrl url, List interceptors, Optional dns) { + private static OkHttpClient createConnectionClient(SignalUrl url, List interceptors, Optional dns, Optional proxy) { try { TrustManager[] trustManagers = BlacklistingTrustManager.createFor(url.getTrustStore()); @@ -1794,6 +1799,10 @@ public class PushServiceSocket { .connectionSpecs(url.getConnectionSpecs().or(Util.immutableList(ConnectionSpec.RESTRICTED_TLS))) .dns(dns.or(Dns.SYSTEM)); + if (proxy.isPresent()) { + builder.socketFactory(new TlsProxySocketFactory(proxy.get().getHost(), proxy.get().getPort(), dns)); + } + builder.sslSocketFactory(new Tls12SocketFactory(context.getSocketFactory()), (X509TrustManager)trustManagers[0]) .connectionSpecs(url.getConnectionSpecs().or(Util.immutableList(ConnectionSpec.RESTRICTED_TLS))) .build(); diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.java index 9b446216d2..afb454040e 100644 --- a/libsignal/service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.java +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.java @@ -9,7 +9,9 @@ import org.whispersystems.signalservice.api.push.TrustStore; import org.whispersystems.signalservice.api.util.CredentialsProvider; import org.whispersystems.signalservice.api.util.SleepTimer; import org.whispersystems.signalservice.api.util.Tls12SocketFactory; +import org.whispersystems.signalservice.api.util.TlsProxySocketFactory; import org.whispersystems.signalservice.api.websocket.ConnectivityListener; +import org.whispersystems.signalservice.internal.configuration.SignalProxy; import org.whispersystems.signalservice.internal.util.BlacklistingTrustManager; import org.whispersystems.signalservice.internal.util.Util; import org.whispersystems.signalservice.internal.util.concurrent.ListenableFuture; @@ -63,6 +65,7 @@ public class WebSocketConnection extends WebSocketListener { private final SleepTimer sleepTimer; private final List interceptors; private final Optional dns; + private final Optional signalProxy; private WebSocket client; private KeepAliveSender keepAliveSender; @@ -76,7 +79,8 @@ public class WebSocketConnection extends WebSocketListener { ConnectivityListener listener, SleepTimer timer, List interceptors, - Optional dns) + Optional dns, + Optional signalProxy) { this.trustStore = trustStore; this.credentialsProvider = credentialsProvider; @@ -85,6 +89,7 @@ public class WebSocketConnection extends WebSocketListener { this.sleepTimer = timer; this.interceptors = interceptors; this.dns = dns; + this.signalProxy = signalProxy; this.attempts = 0; this.connected = false; @@ -120,6 +125,10 @@ public class WebSocketConnection extends WebSocketListener { clientBuilder.addInterceptor(interceptor); } + if (signalProxy.isPresent()) { + clientBuilder.socketFactory(new TlsProxySocketFactory(signalProxy.get().getHost(), signalProxy.get().getPort(), dns)); + } + OkHttpClient okHttpClient = clientBuilder.build(); Request.Builder requestBuilder = new Request.Builder().url(filledUri);