Guard against malformed group ids.

This commit is contained in:
Alan Evans
2020-04-20 12:01:31 -03:00
committed by Greyson Parrelli
parent 00ee6d0bbd
commit 9a8094cb8a
21 changed files with 200 additions and 94 deletions

View File

@@ -0,0 +1,15 @@
package org.thoughtcrime.securesms.groups;
public final class BadGroupIdException extends Exception {
BadGroupIdException(String message) {
super(message);
}
BadGroupIdException() {
super();
}
BadGroupIdException(Exception e) {
super(e);
}
}

View File

@@ -30,30 +30,46 @@ public abstract class GroupId {
return new GroupId.Mms(mmsGroupIdBytes);
}
public static @NonNull GroupId.V1 v1(byte[] gv1GroupIdBytes) {
public static @NonNull GroupId.V1 v1orThrow(byte[] gv1GroupIdBytes) {
try {
return v1(gv1GroupIdBytes);
} catch (BadGroupIdException e) {
throw new AssertionError(e);
}
}
public static @NonNull GroupId.V1 v1(byte[] gv1GroupIdBytes) throws BadGroupIdException {
if (gv1GroupIdBytes.length == V2_BYTE_LENGTH) {
throw new AssertionError();
throw new BadGroupIdException();
}
return new GroupId.V1(gv1GroupIdBytes);
}
public static GroupId.V1 createV1(@NonNull SecureRandom secureRandom) {
return v1(Util.getSecretBytes(secureRandom, V1_MMS_BYTE_LENGTH));
return v1orThrow(Util.getSecretBytes(secureRandom, V1_MMS_BYTE_LENGTH));
}
public static GroupId.Mms createMms(@NonNull SecureRandom secureRandom) {
return mms(Util.getSecretBytes(secureRandom, MMS_BYTE_LENGTH));
}
public static GroupId.V2 v2(@NonNull byte[] bytes) {
public static GroupId.V2 v2orThrow(@NonNull byte[] bytes) {
try {
return v2(bytes);
} catch (BadGroupIdException e) {
throw new AssertionError(e);
}
}
public static GroupId.V2 v2(@NonNull byte[] bytes) throws BadGroupIdException {
if (bytes.length != V2_BYTE_LENGTH) {
throw new AssertionError();
throw new BadGroupIdException();
}
return new GroupId.V2(bytes);
}
public static GroupId.V2 v2(@NonNull GroupIdentifier groupIdentifier) {
return v2(groupIdentifier.serialize());
return v2orThrow(groupIdentifier.serialize());
}
public static GroupId.V2 v2(@NonNull GroupMasterKey masterKey) {
@@ -62,25 +78,41 @@ public abstract class GroupId {
.getGroupIdentifier());
}
public static GroupId.Push push(byte[] bytes) {
public static GroupId.Push push(byte[] bytes) throws BadGroupIdException {
return bytes.length == V2_BYTE_LENGTH ? v2(bytes) : v1(bytes);
}
public static @NonNull GroupId parse(@NonNull String encodedGroupId) {
public static GroupId.Push pushOrThrow(byte[] bytes) {
try {
return push(bytes);
} catch (BadGroupIdException e) {
throw new AssertionError(e);
}
}
public static @NonNull GroupId parseOrThrow(@NonNull String encodedGroupId) {
try {
return parse(encodedGroupId);
} catch (BadGroupIdException e) {
throw new AssertionError(e);
}
}
public static @NonNull GroupId parse(@NonNull String encodedGroupId) throws BadGroupIdException {
try {
if (!isEncodedGroup(encodedGroupId)) {
throw new IOException("Invalid encoding");
throw new BadGroupIdException("Invalid encoding");
}
byte[] bytes = extractDecodedId(encodedGroupId);
return encodedGroupId.startsWith(ENCODED_MMS_GROUP_PREFIX) ? mms(bytes) : push(bytes);
} catch (IOException e) {
throw new AssertionError(e);
throw new BadGroupIdException(e);
}
}
public static @Nullable GroupId parseNullable(@Nullable String encodedGroupId) {
public static @Nullable GroupId parseNullable(@Nullable String encodedGroupId) throws BadGroupIdException {
if (encodedGroupId == null) {
return null;
}
@@ -88,6 +120,14 @@ public abstract class GroupId {
return parse(encodedGroupId);
}
public static @Nullable GroupId parseNullableOrThrow(@Nullable String encodedGroupId) {
if (encodedGroupId == null) {
return null;
}
return parseOrThrow(encodedGroupId);
}
public static boolean isEncodedGroup(@NonNull String groupId) {
return groupId.startsWith(ENCODED_SIGNAL_GROUP_PREFIX) || groupId.startsWith(ENCODED_MMS_GROUP_PREFIX);
}

View File

@@ -70,7 +70,7 @@ public final class GroupV1MessageProcessor {
GroupDatabase database = DatabaseFactory.getGroupDatabase(context);
SignalServiceGroup group = groupV1.get();
GroupId id = GroupId.v1(group.getGroupId());
GroupId id = GroupId.v1orThrow(group.getGroupId());
Optional<GroupRecord> record = database.getGroup(id);
if (record.isPresent() && group.getType() == Type.UPDATE) {
@@ -93,7 +93,7 @@ public final class GroupV1MessageProcessor {
boolean outgoing)
{
GroupDatabase database = DatabaseFactory.getGroupDatabase(context);
GroupId.V1 id = GroupId.v1(group.getGroupId());
GroupId.V1 id = GroupId.v1orThrow(group.getGroupId());
GroupContext.Builder builder = createGroupContext(group);
builder.setType(GroupContext.Type.UPDATE);
@@ -127,7 +127,7 @@ public final class GroupV1MessageProcessor {
{
GroupDatabase database = DatabaseFactory.getGroupDatabase(context);
GroupId.V1 id = GroupId.v1(group.getGroupId());
GroupId.V1 id = GroupId.v1orThrow(group.getGroupId());
Set<RecipientId> recordMembers = new HashSet<>(groupRecord.getMembers());
Set<RecipientId> messageMembers = new HashSet<>();
@@ -203,7 +203,7 @@ public final class GroupV1MessageProcessor {
boolean outgoing)
{
GroupDatabase database = DatabaseFactory.getGroupDatabase(context);
GroupId id = GroupId.v1(group.getGroupId());
GroupId id = GroupId.v1orThrow(group.getGroupId());
List<RecipientId> members = record.getMembers();
GroupContext.Builder builder = createGroupContext(group);
@@ -228,13 +228,13 @@ public final class GroupV1MessageProcessor {
{
if (group.getAvatar().isPresent()) {
ApplicationDependencies.getJobManager()
.add(new AvatarGroupsV1DownloadJob(GroupId.v1(group.getGroupId())));
.add(new AvatarGroupsV1DownloadJob(GroupId.v1orThrow(group.getGroupId())));
}
try {
if (outgoing) {
MmsDatabase mmsDatabase = DatabaseFactory.getMmsDatabase(context);
RecipientId recipientId = DatabaseFactory.getRecipientDatabase(context).getOrInsertFromGroupId(GroupId.v1(group.getGroupId()));
RecipientId recipientId = DatabaseFactory.getRecipientDatabase(context).getOrInsertFromGroupId(GroupId.v1orThrow(group.getGroupId()));
Recipient recipient = Recipient.resolved(recipientId);
OutgoingGroupMediaMessage outgoingMessage = new OutgoingGroupMediaMessage(recipient, storage, null, content.getTimestamp(), 0, false, null, Collections.emptyList(), Collections.emptyList());
long threadId = DatabaseFactory.getThreadDatabase(context).getThreadIdFor(recipient);
@@ -246,7 +246,7 @@ public final class GroupV1MessageProcessor {
} else {
SmsDatabase smsDatabase = DatabaseFactory.getSmsDatabase(context);
String body = Base64.encodeBytes(storage.toByteArray());
IncomingTextMessage incoming = new IncomingTextMessage(Recipient.externalPush(context, content.getSender()).getId(), content.getSenderDevice(), content.getTimestamp(), content.getServerTimestamp(), body, Optional.of(GroupId.v1(group.getGroupId())), 0, content.isNeedsReceipt());
IncomingTextMessage incoming = new IncomingTextMessage(Recipient.externalPush(context, content.getSender()).getId(), content.getSenderDevice(), content.getTimestamp(), content.getServerTimestamp(), body, Optional.of(GroupId.v1orThrow(group.getGroupId())), 0, content.isNeedsReceipt());
IncomingGroupMessage groupMessage = new IncomingGroupMessage(incoming, storage, body);
Optional<InsertResult> insertResult = smsDatabase.insertMessageInbox(groupMessage);

View File

@@ -37,7 +37,7 @@ public class PendingMemberInvitesActivity extends PassphraseRequiredActionBarAct
if (savedInstanceState == null) {
getSupportFragmentManager().beginTransaction()
.replace(R.id.container, PendingMemberInvitesFragment.newInstance(GroupId.parse(getIntent().getStringExtra(GROUP_ID)).requireV2()))
.replace(R.id.container, PendingMemberInvitesFragment.newInstance(GroupId.parseOrThrow(getIntent().getStringExtra(GROUP_ID)).requireV2()))
.commitNow();
}

View File

@@ -80,7 +80,7 @@ public class PendingMemberInvitesFragment extends Fragment {
public void onActivityCreated(@Nullable Bundle savedInstanceState) {
super.onActivityCreated(savedInstanceState);
GroupId.V2 groupId = GroupId.parse(Objects.requireNonNull(requireArguments().getString(GROUP_ID))).requireV2();
GroupId.V2 groupId = GroupId.parseOrThrow(Objects.requireNonNull(requireArguments().getString(GROUP_ID))).requireV2();
PendingMemberInvitesViewModel.Factory factory = new PendingMemberInvitesViewModel.Factory(requireContext(), groupId);