diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/util/TlsProxySocketFactory.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/util/TlsProxySocketFactory.java deleted file mode 100644 index 4aaa445ab9..0000000000 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/util/TlsProxySocketFactory.java +++ /dev/null @@ -1,305 +0,0 @@ -package org.whispersystems.signalservice.api.util; - - - -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.net.InetAddress; -import java.net.Socket; -import java.net.SocketAddress; -import java.net.SocketException; -import java.net.UnknownHostException; -import java.nio.channels.SocketChannel; -import java.security.KeyManagementException; -import java.security.NoSuchAlgorithmException; -import java.util.List; -import java.util.Optional; - -import javax.net.SocketFactory; -import javax.net.ssl.SSLContext; -import javax.net.ssl.SSLSocket; -import javax.net.ssl.SSLSocketFactory; - -import okhttp3.Dns; - -public class TlsProxySocketFactory extends SocketFactory { - - private final SSLSocketFactory system; - - private final String proxyHost; - private final int proxyPort; - private final Optional dns; - - public TlsProxySocketFactory(String proxyHost, int proxyPort, Optional dns) { - try { - SSLContext context = SSLContext.getInstance("TLS"); - context.init(null, null, null); - - this.system = context.getSocketFactory(); - this.proxyHost = proxyHost; - this.proxyPort = proxyPort; - this.dns = dns; - } catch (NoSuchAlgorithmException | KeyManagementException e) { - throw new AssertionError(e); - } - } - - @Override - public Socket createSocket(String host, int port) throws IOException, UnknownHostException { - if (dns.isPresent()) { - List resolved = dns.get().lookup(host); - - if (resolved.size() > 0) { - return createSocket(resolved.get(0), port); - } - } - - return new ProxySocket(system.createSocket(proxyHost, proxyPort)); - } - - @Override - public Socket createSocket(String host, int port, InetAddress localHost, int localPort) throws IOException, UnknownHostException { - if (dns.isPresent()) { - List resolved = dns.get().lookup(host); - - if (resolved.size() > 0) { - return createSocket(resolved.get(0), port, localHost, localPort); - } - } - - return new ProxySocket(system.createSocket(proxyHost, proxyPort, localHost, localPort)); - } - - @Override - public Socket createSocket(InetAddress host, int port) throws IOException { - return new ProxySocket(system.createSocket(proxyHost, proxyPort)); - } - - @Override - public Socket createSocket(InetAddress address, int port, InetAddress localAddress, int localPort) throws IOException { - return new ProxySocket(system.createSocket(proxyHost, proxyPort, localAddress, localPort)); - } - - @Override - public Socket createSocket() throws IOException { - SSLSocket socket = (SSLSocket)system.createSocket(proxyHost, proxyPort); - socket.startHandshake(); - - return new ProxySocket(socket); - } - - private static class ProxySocket extends Socket { - - private final Socket delegate; - - private ProxySocket(Socket delegate) { - this.delegate = delegate; - } - - @Override - public void bind(SocketAddress bindpoint) throws IOException { - delegate.bind(bindpoint); - } - - @Override - public InetAddress getInetAddress() { - return delegate.getInetAddress(); - } - - @Override - public InetAddress getLocalAddress() { - return delegate.getLocalAddress(); - } - - @Override - public int getPort() { - return delegate.getPort(); - } - - @Override - public int getLocalPort() { - return delegate.getLocalPort(); - } - - @Override - public SocketAddress getRemoteSocketAddress() { - return delegate.getRemoteSocketAddress(); - } - - @Override - public SocketAddress getLocalSocketAddress() { - return delegate.getLocalSocketAddress(); - } - - @Override - public SocketChannel getChannel() { - return delegate.getChannel(); - } - - @Override - public InputStream getInputStream() throws IOException { - return delegate.getInputStream(); - } - - @Override - public OutputStream getOutputStream() throws IOException { - return delegate.getOutputStream(); - } - - @Override - public void setTcpNoDelay(boolean on) throws SocketException { - delegate.setTcpNoDelay(on); - } - - @Override - public boolean getTcpNoDelay() throws SocketException { - return delegate.getTcpNoDelay(); - } - - @Override - public void setSoLinger(boolean on, int linger) throws SocketException { - delegate.setSoLinger(on, linger); - } - - @Override - public int getSoLinger() throws SocketException { - return delegate.getSoLinger(); - } - - @Override - public void sendUrgentData(int data) throws IOException { - delegate.sendUrgentData(data); - } - - @Override - public void setOOBInline(boolean on) throws SocketException { - delegate.setOOBInline(on); - } - - @Override - public boolean getOOBInline() throws SocketException { - return delegate.getOOBInline(); - } - - @Override - public void setSoTimeout(int timeout) throws SocketException { - delegate.setSoTimeout(timeout); - } - - @Override - public int getSoTimeout() throws SocketException { - return delegate.getSoTimeout(); - } - - @Override - public void setSendBufferSize(int size) throws SocketException { - delegate.setSendBufferSize(size); - } - - @Override - public int getSendBufferSize() throws SocketException { - return delegate.getSendBufferSize(); - } - - @Override - public void setReceiveBufferSize(int size) throws SocketException { - delegate.setReceiveBufferSize(size); - } - - @Override - public int getReceiveBufferSize() throws SocketException { - return delegate.getReceiveBufferSize(); - } - - @Override - public void setKeepAlive(boolean on) throws SocketException { - delegate.setKeepAlive(on); - } - - @Override - public boolean getKeepAlive() throws SocketException { - return delegate.getKeepAlive(); - } - - @Override - public void setTrafficClass(int tc) throws SocketException { - delegate.setTrafficClass(tc); - } - - @Override - public int getTrafficClass() throws SocketException { - return delegate.getTrafficClass(); - } - - @Override - public void setReuseAddress(boolean on) throws SocketException { - delegate.setReuseAddress(on); - } - - @Override - public boolean getReuseAddress() throws SocketException { - return delegate.getReuseAddress(); - } - - @Override - public void close() throws IOException { - delegate.close(); - } - - @Override - public void shutdownInput() throws IOException { - delegate.shutdownInput(); - } - - @Override - public void shutdownOutput() throws IOException { - delegate.shutdownOutput(); - } - - @Override - public String toString() { - return delegate.toString(); - } - - @Override - public boolean isConnected() { - return delegate.isConnected(); - } - - @Override - public boolean isBound() { - return delegate.isBound(); - } - - @Override - public boolean isClosed() { - return delegate.isClosed(); - } - - @Override - public boolean isInputShutdown() { - return delegate.isInputShutdown(); - } - - @Override - public boolean isOutputShutdown() { - return delegate.isOutputShutdown(); - } - - @Override - public void setPerformancePreferences(int connectionTime, int latency, int bandwidth) { - delegate.setPerformancePreferences(connectionTime, latency, bandwidth); - } - - @Override - public void connect(SocketAddress endpoint) throws IOException { - // Already connected - } - - @Override - public void connect(SocketAddress endpoint, int timeout) throws IOException { - // Already connected - } - } -} diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/util/TlsProxySocketFactory.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/util/TlsProxySocketFactory.kt new file mode 100644 index 0000000000..05c00e66c6 --- /dev/null +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/util/TlsProxySocketFactory.kt @@ -0,0 +1,274 @@ +package org.whispersystems.signalservice.api.util + +import okhttp3.Dns +import java.io.IOException +import java.io.InputStream +import java.io.OutputStream +import java.net.InetAddress +import java.net.Socket +import java.net.SocketAddress +import java.net.SocketException +import java.net.UnknownHostException +import java.nio.channels.SocketChannel +import java.security.KeyManagementException +import java.security.NoSuchAlgorithmException +import java.util.Optional +import javax.net.SocketFactory +import javax.net.ssl.SSLContext +import javax.net.ssl.SSLSocket +import javax.net.ssl.SSLSocketFactory + +class TlsProxySocketFactory( + private val proxyHost: String?, + private val proxyPort: Int, + private val dns: Optional +) : SocketFactory() { + + private val system: SSLSocketFactory = try { + val context = SSLContext.getInstance("TLS") + context.init(null, null, null) + + context.socketFactory + } catch (e: NoSuchAlgorithmException) { + throw AssertionError(e) + } catch (e: KeyManagementException) { + throw AssertionError(e) + } + + @Throws(IOException::class, UnknownHostException::class) + override fun createSocket(host: String, port: Int): Socket { + if (dns.isPresent) { + val resolved = dns.get().lookup(host) + + if (resolved.isNotEmpty()) { + return createSocket(resolved[0], port) + } + } + + return ProxySocket(system.createSocket(proxyHost, proxyPort)) + } + + @Throws(IOException::class, UnknownHostException::class) + override fun createSocket(host: String, port: Int, localHost: InetAddress, localPort: Int): Socket { + if (dns.isPresent) { + val resolved = dns.get().lookup(host) + + if (resolved.isNotEmpty()) { + return createSocket(resolved[0], port, localHost, localPort) + } + } + + return ProxySocket(system.createSocket(proxyHost, proxyPort, localHost, localPort)) + } + + @Throws(IOException::class) + override fun createSocket(host: InetAddress, port: Int): Socket { + return ProxySocket(system.createSocket(proxyHost, proxyPort)) + } + + @Throws(IOException::class) + override fun createSocket(address: InetAddress, port: Int, localAddress: InetAddress, localPort: Int): Socket { + return ProxySocket(system.createSocket(proxyHost, proxyPort, localAddress, localPort)) + } + + @Throws(IOException::class) + override fun createSocket(): Socket { + val socket = system.createSocket(proxyHost, proxyPort) as SSLSocket + socket.startHandshake() + + return ProxySocket(socket) + } + + private class ProxySocket(private val delegate: Socket) : Socket() { + @Throws(IOException::class) + override fun connect(endpoint: SocketAddress) { + // Already connected + } + + @Throws(IOException::class) + override fun connect(endpoint: SocketAddress, timeout: Int) { + // Already connected + } + + @Throws(IOException::class) + override fun bind(bindpoint: SocketAddress) { + delegate.bind(bindpoint) + } + + override fun getInetAddress(): InetAddress { + return delegate.inetAddress + } + + override fun getLocalAddress(): InetAddress { + return delegate.localAddress + } + + override fun getPort(): Int { + return delegate.port + } + + override fun getLocalPort(): Int { + return delegate.localPort + } + + override fun getRemoteSocketAddress(): SocketAddress { + return delegate.remoteSocketAddress + } + + override fun getLocalSocketAddress(): SocketAddress { + return delegate.localSocketAddress + } + + override fun getChannel(): SocketChannel { + return delegate.channel + } + + @Throws(IOException::class) + override fun getInputStream(): InputStream { + return delegate.getInputStream() + } + + @Throws(IOException::class) + override fun getOutputStream(): OutputStream { + return delegate.getOutputStream() + } + + @Throws(SocketException::class) + override fun setTcpNoDelay(on: Boolean) { + delegate.tcpNoDelay = on + } + + @Throws(SocketException::class) + override fun getTcpNoDelay(): Boolean { + return delegate.tcpNoDelay + } + + @Throws(SocketException::class) + override fun setSoLinger(on: Boolean, linger: Int) { + delegate.setSoLinger(on, linger) + } + + @Throws(SocketException::class) + override fun getSoLinger(): Int { + return delegate.soLinger + } + + @Throws(IOException::class) + override fun sendUrgentData(data: Int) { + delegate.sendUrgentData(data) + } + + @Throws(SocketException::class) + override fun setOOBInline(on: Boolean) { + delegate.oobInline = on + } + + @Throws(SocketException::class) + override fun getOOBInline(): Boolean { + return delegate.oobInline + } + + @Throws(SocketException::class) + override fun setSoTimeout(timeout: Int) { + delegate.soTimeout = timeout + } + + @Throws(SocketException::class) + override fun getSoTimeout(): Int { + return delegate.soTimeout + } + + @Throws(SocketException::class) + override fun setSendBufferSize(size: Int) { + delegate.sendBufferSize = size + } + + @Throws(SocketException::class) + override fun getSendBufferSize(): Int { + return delegate.sendBufferSize + } + + @Throws(SocketException::class) + override fun setReceiveBufferSize(size: Int) { + delegate.receiveBufferSize = size + } + + @Throws(SocketException::class) + override fun getReceiveBufferSize(): Int { + return delegate.receiveBufferSize + } + + @Throws(SocketException::class) + override fun setKeepAlive(on: Boolean) { + delegate.keepAlive = on + } + + @Throws(SocketException::class) + override fun getKeepAlive(): Boolean { + return delegate.keepAlive + } + + @Throws(SocketException::class) + override fun setTrafficClass(tc: Int) { + delegate.trafficClass = tc + } + + @Throws(SocketException::class) + override fun getTrafficClass(): Int { + return delegate.trafficClass + } + + @Throws(SocketException::class) + override fun setReuseAddress(on: Boolean) { + delegate.reuseAddress = on + } + + @Throws(SocketException::class) + override fun getReuseAddress(): Boolean { + return delegate.reuseAddress + } + + @Throws(IOException::class) + override fun close() { + delegate.close() + } + + @Throws(IOException::class) + override fun shutdownInput() { + delegate.shutdownInput() + } + + @Throws(IOException::class) + override fun shutdownOutput() { + delegate.shutdownOutput() + } + + override fun toString(): String { + return delegate.toString() + } + + override fun isConnected(): Boolean { + return delegate.isConnected + } + + override fun isBound(): Boolean { + return delegate.isBound + } + + override fun isClosed(): Boolean { + return delegate.isClosed + } + + override fun isInputShutdown(): Boolean { + return delegate.isInputShutdown + } + + override fun isOutputShutdown(): Boolean { + return delegate.isOutputShutdown + } + + override fun setPerformancePreferences(connectionTime: Int, latency: Int, bandwidth: Int) { + delegate.setPerformancePreferences(connectionTime, latency, bandwidth) + } + } +}