From e552b5160f7a861a797ad4eb79b7527b7e8ba11f Mon Sep 17 00:00:00 2001 From: Greyson Parrelli Date: Mon, 28 Feb 2022 11:22:58 -0500 Subject: [PATCH] Implement CdshV2Service. --- ...veryV3.java => ContactDiscoveryHsmV1.java} | 9 +- .../contacts/sync/DirectoryHelper.java | 2 +- .../api/SignalServiceAccountManager.java | 29 ++- .../api/services/CdshSocket.java | 51 ++++- .../api/services/CdshV1Service.java | 30 +-- .../api/services/CdshV2Service.java | 216 ++++++++++++++++++ 6 files changed, 300 insertions(+), 37 deletions(-) rename app/src/main/java/org/thoughtcrime/securesms/contacts/sync/{ContactDiscoveryV3.java => ContactDiscoveryHsmV1.java} (91%) create mode 100644 libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/CdshV2Service.java diff --git a/app/src/main/java/org/thoughtcrime/securesms/contacts/sync/ContactDiscoveryV3.java b/app/src/main/java/org/thoughtcrime/securesms/contacts/sync/ContactDiscoveryHsmV1.java similarity index 91% rename from app/src/main/java/org/thoughtcrime/securesms/contacts/sync/ContactDiscoveryV3.java rename to app/src/main/java/org/thoughtcrime/securesms/contacts/sync/ContactDiscoveryHsmV1.java index acfe806ede..c8e53db9b7 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/contacts/sync/ContactDiscoveryV3.java +++ b/app/src/main/java/org/thoughtcrime/securesms/contacts/sync/ContactDiscoveryHsmV1.java @@ -18,15 +18,14 @@ import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.UUID; import java.util.stream.Collectors; /** - * Uses CDS to map E164's to UUIDs. + * Uses CDSHv1 to map E164's to UUIDs. */ -class ContactDiscoveryV3 { +class ContactDiscoveryHsmV1 { - private static final String TAG = Log.tag(ContactDiscoveryV3.class); + private static final String TAG = Log.tag(ContactDiscoveryHsmV1.class); private static final int MAX_NUMBERS = 20_500; @@ -47,7 +46,7 @@ class ContactDiscoveryV3 { SignalServiceAccountManager accountManager = ApplicationDependencies.getSignalServiceAccountManager(); try { - Map results = accountManager.getRegisteredUsersWithCdsh(sanitizedNumbers, BuildConfig.CDSH_PUBLIC_KEY, BuildConfig.CDSH_CODE_HASH); + Map results = accountManager.getRegisteredUsersWithCdshV1(sanitizedNumbers, BuildConfig.CDSH_PUBLIC_KEY, BuildConfig.CDSH_CODE_HASH); FuzzyPhoneNumberHelper.OutputResult outputResult = FuzzyPhoneNumberHelper.generateOutput(results, inputResult); return new DirectoryResult(outputResult.getNumbers(), outputResult.getRewrites(), ignoredNumbers); diff --git a/app/src/main/java/org/thoughtcrime/securesms/contacts/sync/DirectoryHelper.java b/app/src/main/java/org/thoughtcrime/securesms/contacts/sync/DirectoryHelper.java index 3bd5b5244c..e2f635b261 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/contacts/sync/DirectoryHelper.java +++ b/app/src/main/java/org/thoughtcrime/securesms/contacts/sync/DirectoryHelper.java @@ -233,7 +233,7 @@ public class DirectoryHelper { DirectoryResult result; if (FeatureFlags.cdsh()) { - result = ContactDiscoveryV3.getDirectoryResult(databaseNumbers, systemNumbers); + result = ContactDiscoveryHsmV1.getDirectoryResult(databaseNumbers, systemNumbers); } else { result = ContactDiscoveryV2.getDirectoryResult(context, databaseNumbers, systemNumbers); } diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceAccountManager.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceAccountManager.java index 0d76484196..88efc013f6 100644 --- a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceAccountManager.java +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceAccountManager.java @@ -44,6 +44,7 @@ import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulRespons import org.whispersystems.signalservice.api.push.exceptions.NotFoundException; import org.whispersystems.signalservice.api.push.exceptions.PushNetworkException; import org.whispersystems.signalservice.api.services.CdshV1Service; +import org.whispersystems.signalservice.api.services.CdshV2Service; import org.whispersystems.signalservice.api.storage.SignalStorageCipher; import org.whispersystems.signalservice.api.storage.SignalStorageManifest; import org.whispersystems.signalservice.api.storage.SignalStorageModels; @@ -505,7 +506,7 @@ public class SignalServiceAccountManager { } } - public Map getRegisteredUsersWithCdsh(Set e164numbers, String hexPublicKey, String hexCodeHash) + public Map getRegisteredUsersWithCdshV1(Set e164numbers, String hexPublicKey, String hexCodeHash) throws IOException { CdshAuthResponse auth = pushServiceSocket.getCdshAuth(); @@ -530,6 +531,32 @@ public class SignalServiceAccountManager { } } + public CdshV2Service.Response getRegisteredUsersWithCdshV2(Set previousE164s, Set newE164s, Map serviceIds, Optional token, String hexPublicKey, String hexCodeHash) + throws IOException + { + CdshAuthResponse auth = pushServiceSocket.getCdshAuth(); + CdshV2Service service = new CdshV2Service(configuration, hexPublicKey, hexCodeHash); + CdshV2Service.Request request = new CdshV2Service.Request(previousE164s, newE164s, serviceIds, token); + Single> single = service.getRegisteredUsers(auth.getUsername(), auth.getPassword(), request); + + ServiceResponse serviceResponse; + try { + serviceResponse = single.blockingGet(); + } catch (Exception e) { + throw new RuntimeException("Unexpected exception when retrieving registered users!", e); + } + + if (serviceResponse.getResult().isPresent()) { + return serviceResponse.getResult().get(); + } else if (serviceResponse.getApplicationError().isPresent()) { + throw new IOException(serviceResponse.getApplicationError().get()); + } else if (serviceResponse.getExecutionError().isPresent()) { + throw new IOException(serviceResponse.getExecutionError().get()); + } else { + throw new IOException("Missing result!"); + } + } + public Optional getStorageManifest(StorageKey storageKey) throws IOException { try { diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/CdshSocket.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/CdshSocket.java index 536beb8d75..9fa7a414cf 100644 --- a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/CdshSocket.java +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/CdshSocket.java @@ -1,31 +1,28 @@ package org.whispersystems.signalservice.api.services; +import org.signal.cds.ClientRequest; import org.signal.cds.ClientResponse; import org.signal.libsignal.hsmenclave.HsmEnclaveClient; import org.whispersystems.libsignal.logging.Log; import org.whispersystems.libsignal.util.Pair; -import org.whispersystems.signalservice.api.push.ACI; import org.whispersystems.signalservice.api.push.TrustStore; import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException; import org.whispersystems.signalservice.api.util.Tls12SocketFactory; -import org.whispersystems.signalservice.internal.ServiceResponse; import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration; import org.whispersystems.signalservice.internal.util.BlacklistingTrustManager; import org.whispersystems.signalservice.internal.util.Hex; import org.whispersystems.signalservice.internal.util.Util; import org.whispersystems.util.Base64; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.security.KeyManagementException; import java.security.NoSuchAlgorithmException; import java.util.Collections; -import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; -import java.util.stream.Collectors; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSocketFactory; @@ -40,6 +37,9 @@ import okhttp3.Response; import okhttp3.WebSocket; import okhttp3.WebSocketListener; +/** + * Handles the websocket and general lifecycle of a CDSH request. + */ final class CdshSocket { private static final String TAG = CdshSocket.class.getSimpleName(); @@ -49,11 +49,13 @@ final class CdshSocket { private final String baseUrl; private final String hexPublicKey; private final String hexCodeHash; + private final Version version; - CdshSocket(SignalServiceConfiguration configuration, String hexPublicKey, String hexCodeHash) { + CdshSocket(SignalServiceConfiguration configuration, String hexPublicKey, String hexCodeHash, Version version) { this.baseUrl = configuration.getSignalCdshUrls()[0].getUrl(); this.hexPublicKey = hexPublicKey; this.hexCodeHash = hexCodeHash; + this.version = version; Pair socketFactory = createTlsSocketFactory(configuration.getSignalCdshUrls()[0].getTrustStore()); @@ -73,11 +75,11 @@ final class CdshSocket { } } - Observable connect(String username, String password, List requests) { + Observable connect(String username, String password, List requests) { return Observable.create(emitter -> { - AtomicReference stage = new AtomicReference<>(Stage.WAITING_TO_INITIALIZE); + AtomicReference stage = new AtomicReference<>(Stage.WAITING_TO_INITIALIZE); - String url = String.format("%s/discovery/%s/%s", baseUrl, hexPublicKey, hexCodeHash); + String url = String.format("%s/discovery/%s/%s", baseUrl, hexPublicKey, hexCodeHash); Request request = new Request.Builder() .url(url) .addHeader("Authorization", basicAuth(username, password)) @@ -91,8 +93,10 @@ final class CdshSocket { enclave.completeHandshake(bytes.toByteArray()); stage.set(Stage.WAITING_FOR_RESPONSE); - for (byte[] request : requests) { - webSocket.send(okio.ByteString.of(enclave.establishedSend(request))); + for (ClientRequest request : requests) { + byte[] plaintextBytes = requestToBytes(request, version); + byte[] ciphertextBytes = enclave.establishedSend(plaintextBytes); + webSocket.send(okio.ByteString.of(ciphertextBytes)); } break; @@ -139,6 +143,17 @@ final class CdshSocket { }); } + private static byte[] requestToBytes(ClientRequest request, Version version) { + ByteArrayOutputStream requestStream = new ByteArrayOutputStream(); + try { + requestStream.write(version.getValue()); + requestStream.write(request.toByteArray()); + } catch (IOException e) { + throw new AssertionError("Failed to write bytes!"); + } + return requestStream.toByteArray(); + } + private static String basicAuth(String username, String password) { return "Basic " + Base64.encodeBytes((username + ":" + password).getBytes(StandardCharsets.UTF_8)); } @@ -158,4 +173,18 @@ final class CdshSocket { private enum Stage { WAITING_TO_INITIALIZE, WAITING_FOR_RESPONSE, FAILURE } + + enum Version { + V1(1), V2(2); + + private final int value; + + Version(int value) { + this.value = value; + } + + public int getValue() { + return value; + } + } } diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/CdshV1Service.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/CdshV1Service.java index 3f27b5b07e..70aac0d72f 100644 --- a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/CdshV1Service.java +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/CdshV1Service.java @@ -30,7 +30,6 @@ public final class CdshV1Service { private static final String TAG = CdshV1Service.class.getSimpleName(); - private static final int VERSION = 1; private static final int MAX_E164S_PER_REQUEST = 5000; private static final UUID EMPTY_ACI = new UUID(0, 0); private static final int RESPONSE_ITEM_SIZE = 8 + 16 + 16; // 1 uint64 + 2 UUIDs @@ -38,14 +37,14 @@ public final class CdshV1Service { private final CdshSocket cdshSocket; public CdshV1Service(SignalServiceConfiguration configuration, String hexPublicKey, String hexCodeHash) { - this.cdshSocket = new CdshSocket(configuration, hexPublicKey, hexCodeHash); + this.cdshSocket = new CdshSocket(configuration, hexPublicKey, hexCodeHash, CdshSocket.Version.V1); } public Single>> getRegisteredUsers(String username, String password, Set e164Numbers) { List addressBook = e164Numbers.stream().map(e -> e.substring(1)).collect(Collectors.toList()); return cdshSocket - .connect(username, password, buildPlaintextRequests(addressBook)) + .connect(username, password, buildClientRequests(addressBook)) .map(CdshV1Service::parseEntries) .collect(Collectors.toList()) .flatMap(pages -> { @@ -83,10 +82,10 @@ public final class CdshV1Service { return out; } - private static List buildPlaintextRequests(List addressBook) { - List out = new ArrayList<>((addressBook.size() / MAX_E164S_PER_REQUEST) + 1); - ByteString.Output e164Page = ByteString.newOutput(); - int pageSize = 0; + private static List buildClientRequests(List addressBook) { + List out = new ArrayList<>((addressBook.size() / MAX_E164S_PER_REQUEST) + 1); + ByteString.Output e164Page = ByteString.newOutput(); + int pageSize = 0; for (String address : addressBook) { if (pageSize >= MAX_E164S_PER_REQUEST) { @@ -111,17 +110,10 @@ public final class CdshV1Service { return out; } - private static byte[] e164sToRequest(ByteString e164s, boolean more) { - try (ByteArrayOutputStream outputStream = new ByteArrayOutputStream()) { - outputStream.write(VERSION); - ClientRequest.newBuilder() - .setNewE164S(e164s) - .setHasMore(more) - .build() - .writeTo(outputStream); - return outputStream.toByteArray(); - } catch (IOException e) { - throw new AssertionError("Failed to write protobuf to the output stream?"); - } + private static ClientRequest e164sToRequest(ByteString e164s, boolean more) { + return ClientRequest.newBuilder() + .setNewE164S(e164s) + .setHasMore(more) + .build(); } } diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/CdshV2Service.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/CdshV2Service.java new file mode 100644 index 0000000000..5760708608 --- /dev/null +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/CdshV2Service.java @@ -0,0 +1,216 @@ +package org.whispersystems.signalservice.api.services; + +import com.google.protobuf.ByteString; + +import org.signal.cds.ClientRequest; +import org.signal.cds.ClientResponse; +import org.signal.zkgroup.profiles.ProfileKey; +import org.whispersystems.libsignal.util.ByteUtil; +import org.whispersystems.libsignal.util.guava.Optional; +import org.whispersystems.signalservice.api.crypto.UnidentifiedAccess; +import org.whispersystems.signalservice.api.push.ACI; +import org.whispersystems.signalservice.api.push.PNI; +import org.whispersystems.signalservice.api.push.ServiceId; +import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException; +import org.whispersystems.signalservice.api.util.UuidUtil; +import org.whispersystems.signalservice.internal.ServiceResponse; +import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; + +import io.reactivex.rxjava3.core.Single; + +/** + * Handles network interactions with CDSHv2, the HSM-backed CDS service. + */ +public final class CdshV2Service { + + private static final String TAG = CdshV2Service.class.getSimpleName(); + + private static final UUID EMPTY_UUID = new UUID(0, 0); + private static final int RESPONSE_ITEM_SIZE = 8 + 16 + 16; // 1 uint64 + 2 UUIDs + + private final CdshSocket cdshSocket; + + public CdshV2Service(SignalServiceConfiguration configuration, String hexPublicKey, String hexCodeHash) { + this.cdshSocket = new CdshSocket(configuration, hexPublicKey, hexCodeHash, CdshSocket.Version.V2); + } + + public Single> getRegisteredUsers(String username, String password, Request request) { + return cdshSocket + .connect(username, password, buildClientRequests(request)) + .map(CdshV2Service::parseEntries) + .collect(Collectors.toList()) + .flatMap(pages -> { + byte[] token = null; + Map all = new HashMap<>(); + + for (Response page : pages) { + all.putAll(page.getResults()); + token = token == null ? page.getToken() : token; + } + + if (token == null) { + throw new IOException("No token found in response!"); + } + + return Single.just(new Response(all, token)); + }) + .map(result -> ServiceResponse.forResult(result, 200, null)) + .onErrorReturn(error -> { + if (error instanceof NonSuccessfulResponseCodeException) { + int status = ((NonSuccessfulResponseCodeException) error).getCode(); + return ServiceResponse.forApplicationError(error, status, null); + } else { + return ServiceResponse.forUnknownError(error); + } + }); + } + + private static Response parseEntries(ClientResponse clientResponse) { + byte[] token = !clientResponse.getToken().isEmpty() ? clientResponse.getToken().toByteArray() : null; + Map results = new HashMap<>(); + ByteBuffer parser = clientResponse.getE164PniAciTriples().asReadOnlyByteBuffer(); + + while (parser.remaining() >= RESPONSE_ITEM_SIZE) { + String e164 = "+" + parser.getLong(); + UUID pniUuid = new UUID(parser.getLong(), parser.getLong()); + UUID aciUuid = new UUID(parser.getLong(), parser.getLong()); + + if (!pniUuid.equals(EMPTY_UUID)) { + PNI pni = PNI.from(pniUuid); + ACI aci = aciUuid.equals(EMPTY_UUID) ? null : ACI.from(aciUuid); + results.put(e164, new ResponseItem(pni, Optional.fromNullable(aci))); + } + } + + return new Response(results, token); + } + + private static List buildClientRequests(Request request) { + List previousE164s = parseAndSortE164Strings(request.previousE164s); + List newE164s = parseAndSortE164Strings(request.newE164s); + List removedE164s = parseAndSortE164Strings(request.removedE164s); + + return Collections.singletonList(ClientRequest.newBuilder() + .setPrevE164S(toByteString(previousE164s)) + .setNewE164S(toByteString(newE164s)) + .setDiscardE164S(toByteString(removedE164s)) + .setAciUakPairs(toByteString(request.serviceIds)) + .setToken(ByteString.copyFrom(request.token)) + .setHasMore(false) + .build()); + } + + private static ByteString toByteString(List numbers) { + ByteString.Output os = ByteString.newOutput(); + + for (long number : numbers) { + try { + os.write(ByteUtil.longToByteArray(number)); + } catch (IOException e) { + throw new AssertionError("Failed to write long to ByteString", e); + } + } + + return os.toByteString(); + } + + private static ByteString toByteString(Map serviceIds) { + ByteString.Output os = ByteString.newOutput(); + + for (Map.Entry entry : serviceIds.entrySet()) { + try { + os.write(UuidUtil.toByteArray(entry.getKey().uuid())); + os.write(UnidentifiedAccess.deriveAccessKeyFrom(entry.getValue())); + } catch (IOException e) { + throw new AssertionError("Failed to write long to ByteString", e); + } + } + + return os.toByteString(); + } + + private static List parseAndSortE164Strings(Collection e164s) { + return e164s.stream() + .map(Long::parseLong) + .sorted() + .collect(Collectors.toList()); + + } + + public static final class Request { + private final Set previousE164s; + private final Set newE164s; + private final Set removedE164s; + + private final Map serviceIds; + + private final byte[] token; + + public Request(Set previousE164s, Set newE164s, Map serviceIds, Optional token) { + this.previousE164s = previousE164s; + this.newE164s = newE164s; + this.removedE164s = Collections.emptySet(); + this.serviceIds = serviceIds; + this.token = token.isPresent() ? token.get() : new byte[32]; + } + + public int totalE164s() { + return previousE164s.size() + newE164s.size() - removedE164s.size(); + } + + public int serviceIdSize() { + return previousE164s.size() + newE164s.size() + removedE164s.size() + serviceIds.size(); + } + } + + public static final class Response { + private final Map results; + private final byte[] token; + + public Response(Map results, byte[] token) { + this.results = results; + this.token = token; + } + + public Map getResults() { + return results; + } + + public byte[] getToken() { + return token; + } + } + + public static final class ResponseItem { + private final PNI pni; + private final Optional aci; + + public ResponseItem(PNI pni, Optional aci) { + this.pni = pni; + this.aci = aci; + } + + public PNI getPni() { + return pni; + } + + public Optional getAci() { + return aci; + } + + public boolean hasAci() { + return aci.isPresent(); + } + } +}