Re-organize gradle modules.

This commit is contained in:
Greyson Parrelli
2025-12-31 11:56:13 -05:00
committed by jeffrey-signal
parent f4863efb2e
commit e162eb27c7
1444 changed files with 111 additions and 144 deletions

View File

@@ -0,0 +1,5 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest
xmlns:android="http://schemas.android.com/apk/res/android">
</manifest>

View File

@@ -0,0 +1,10 @@
package androidx.documentfile.provider
/**
* Located in androidx package as [TreeDocumentFile] is package protected.
*
* @return true if can be used like a tree document file (e.g., use content resolver queries)
*/
fun DocumentFile.isTreeDocumentFile(): Boolean {
return this is TreeDocumentFile
}

View File

@@ -0,0 +1,30 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import android.app.Activity
import android.os.Build
import androidx.annotation.AnimRes
val Activity.OVERRIDE_TRANSITION_OPEN_COMPAT: Int get() = 0
val Activity.OVERRIDE_TRANSITION_CLOSE_COMPAT: Int get() = 1
fun Activity.overrideActivityTransitionCompat(overrideType: Int, @AnimRes enterAnim: Int, @AnimRes exitAnim: Int) {
if (Build.VERSION.SDK_INT >= 34) {
overrideActivityTransition(overrideType, enterAnim, exitAnim)
} else {
@Suppress("DEPRECATION")
overridePendingTransition(enterAnim, exitAnim)
}
}
fun Activity.isInMultiWindowModeCompat(): Boolean {
return if (Build.VERSION.SDK_INT >= 24) {
isInMultiWindowMode
} else {
false
}
}

View File

@@ -0,0 +1,26 @@
package org.signal.core.util;
import android.app.ActivityManager;
import android.content.Context;
import android.content.Intent;
import androidx.annotation.NonNull;
import androidx.core.content.ContextCompat;
public final class AppUtil {
private AppUtil() {}
/**
* Restarts the application. Should generally only be used for internal tools.
*/
public static void restart(@NonNull Context context) {
String packageName = context.getPackageName();
Intent defaultIntent = context.getPackageManager().getLaunchIntentForPackage(packageName);
defaultIntent.addFlags(Intent.FLAG_ACTIVITY_NEW_TASK);
context.startActivity(defaultIntent);
Runtime.getRuntime().exit(0);
}
}

View File

@@ -0,0 +1,88 @@
package org.signal.core.util
import android.database.Cursor
import kotlin.math.max
class AsciiArt {
private class Table(
private val columns: List<String>,
private val rows: List<List<String>>
) {
override fun toString(): String {
val columnWidths = columns.map { column -> column.length }.toIntArray()
rows.forEach { row: List<String> ->
columnWidths.forEachIndexed { index, currentMax ->
columnWidths[index] = max(row[index].length, currentMax)
}
}
val builder = StringBuilder()
columns.forEachIndexed { index, column ->
builder.append(COLUMN_DIVIDER).append(" ").append(rightPad(column, columnWidths[index])).append(" ")
}
builder.append(COLUMN_DIVIDER)
builder.append("\n")
columnWidths.forEach { width ->
builder.append(COLUMN_DIVIDER)
builder.append(ROW_DIVIDER.repeat(width + 2))
}
builder.append(COLUMN_DIVIDER)
builder.append("\n")
rows.forEach { row ->
row.forEachIndexed { index, column ->
builder.append(COLUMN_DIVIDER).append(" ").append(rightPad(column, columnWidths[index])).append(" ")
}
builder.append(COLUMN_DIVIDER)
builder.append("\n")
}
return builder.toString()
}
}
companion object {
private const val COLUMN_DIVIDER = "|"
private const val ROW_DIVIDER = "-"
/**
* Will return a string representing a table of the provided cursor. The caller is responsible for the lifecycle of the cursor.
*/
@JvmStatic
fun tableFor(cursor: Cursor): String {
val columns: MutableList<String> = mutableListOf()
val rows: MutableList<List<String>> = mutableListOf()
columns.addAll(cursor.columnNames)
while (cursor.moveToNext()) {
val row: MutableList<String> = mutableListOf()
for (i in 0 until columns.size) {
row += cursor.getString(i)
}
rows += row
}
return Table(columns, rows).toString()
}
private fun rightPad(value: String, length: Int): String {
if (value.length >= length) {
return value
}
val out = java.lang.StringBuilder(value)
while (out.length < length) {
out.append(" ")
}
return out.toString()
}
}
}

View File

@@ -0,0 +1,146 @@
package org.signal.core.util;
import android.os.Build;
import androidx.annotation.NonNull;
import androidx.annotation.RequiresApi;
import java.util.Iterator;
public abstract class BreakIteratorCompat implements Iterable<CharSequence> {
public static final int DONE = -1;
private CharSequence charSequence;
public abstract int first();
public abstract int next();
public void setText(CharSequence charSequence) {
this.charSequence = charSequence;
}
public static BreakIteratorCompat getInstance() {
if (Build.VERSION.SDK_INT >= 24) {
return new AndroidIcuBreakIterator();
} else {
return new FallbackBreakIterator();
}
}
public int countBreaks() {
int breakCount = 0;
first();
while (next() != DONE) {
breakCount++;
}
return breakCount;
}
@Override
public @NonNull Iterator<CharSequence> iterator() {
return new Iterator<CharSequence>() {
int index1 = BreakIteratorCompat.this.first();
int index2 = BreakIteratorCompat.this.next();
@Override
public boolean hasNext() {
return index2 != DONE;
}
@Override
public CharSequence next() {
CharSequence c = index2 != DONE ? charSequence.subSequence(index1, index2) : "";
index1 = index2;
index2 = BreakIteratorCompat.this.next();
return c;
}
};
}
/**
* Take {@param atMost} graphemes from the start of string.
*/
public final CharSequence take(int atMost) {
if (atMost <= 0) return "";
StringBuilder stringBuilder = new StringBuilder(charSequence.length());
int count = 0;
for (CharSequence grapheme : this) {
stringBuilder.append(grapheme);
count++;
if (count >= atMost) break;
}
return stringBuilder.toString();
}
/**
* An BreakIteratorCompat implementation that delegates calls to `android.icu.text.BreakIterator`.
* This class handles grapheme clusters fine but requires Android API >= 24.
*/
@RequiresApi(24)
private static class AndroidIcuBreakIterator extends BreakIteratorCompat {
private final android.icu.text.BreakIterator breakIterator = android.icu.text.BreakIterator.getCharacterInstance();
@Override
public int first() {
return breakIterator.first();
}
@Override
public int next() {
return breakIterator.next();
}
@Override
public void setText(CharSequence charSequence) {
super.setText(charSequence);
if (Build.VERSION.SDK_INT >= 29) {
breakIterator.setText(charSequence);
} else {
breakIterator.setText(charSequence.toString());
}
}
}
/**
* An BreakIteratorCompat implementation that delegates calls to `java.text.BreakIterator`.
* This class may or may not handle grapheme clusters well depending on the underlying implementation.
* In the emulator, API 23 implements ICU version of the BreakIterator so that it handles grapheme
* clusters fine. But API 21 implements RuleBasedIterator which does not handle grapheme clusters.
* <p>
* If it doesn't handle grapheme clusters correctly, in most cases the combined characters are
* broken up into pieces when the code tries to trim a string. For example, an emoji that is
* a combination of a person, gender and skin tone, trimming the character using this class may result
* in trimming the parts of the character, e.g. a dark skin frowning woman emoji may result in
* a neutral skin frowning woman emoji.
*/
private static class FallbackBreakIterator extends BreakIteratorCompat {
private final java.text.BreakIterator breakIterator = java.text.BreakIterator.getCharacterInstance();
@Override
public int first() {
return breakIterator.first();
}
@Override
public int next() {
return breakIterator.next();
}
@Override
public void setText(CharSequence charSequence) {
super.setText(charSequence);
breakIterator.setText(charSequence.toString());
}
}
}

View File

@@ -0,0 +1,44 @@
@file:JvmName("BundleExtensions")
package org.signal.core.util
import android.os.Build
import android.os.Bundle
import android.os.Parcelable
import java.io.Serializable
fun <T : Serializable> Bundle.getSerializableCompat(key: String, clazz: Class<T>): T? {
return if (Build.VERSION.SDK_INT >= 33) {
this.getSerializable(key, clazz)
} else {
@Suppress("DEPRECATION", "UNCHECKED_CAST")
this.getSerializable(key) as T?
}
}
fun <T : Parcelable> Bundle.getParcelableCompat(key: String, clazz: Class<T>): T? {
return if (Build.VERSION.SDK_INT >= 33) {
this.getParcelable(key, clazz)
} else {
@Suppress("DEPRECATION")
this.getParcelable(key)
}
}
fun <T : Parcelable> Bundle.requireParcelableCompat(key: String, clazz: Class<T>): T {
return if (Build.VERSION.SDK_INT >= 33) {
this.getParcelable(key, clazz)!!
} else {
@Suppress("DEPRECATION")
this.getParcelable(key)!!
}
}
fun <T : Parcelable> Bundle.getParcelableArrayListCompat(key: String, clazz: Class<T>): ArrayList<T>? {
return if (Build.VERSION.SDK_INT >= 33) {
this.getParcelableArrayList(key, clazz)
} else {
@Suppress("DEPRECATION")
this.getParcelableArrayList(key)
}
}

View File

@@ -0,0 +1,76 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import android.text.InputFilter
import android.text.Spanned
/**
* An [InputFilter] that prevents the target text from growing beyond [byteLimit] bytes when using UTF-8 encoding.
*/
class ByteLimitInputFilter(private val byteLimit: Int) : InputFilter {
override fun filter(source: CharSequence?, start: Int, end: Int, dest: Spanned?, dstart: Int, dend: Int): CharSequence? {
if (source == null || dest == null) {
return null
}
val insertText = source.subSequence(start, end)
val beforeText = dest.subSequence(0, dstart)
val afterText = dest.subSequence(dend, dest.length)
val insertByteLength = insertText.utf8Size()
val beforeByteLength = beforeText.utf8Size()
val afterByteLength = afterText.utf8Size()
val resultByteSize = beforeByteLength + insertByteLength + afterByteLength
if (resultByteSize <= byteLimit) {
return null
}
val availableBytes = byteLimit - beforeByteLength - afterByteLength
if (availableBytes <= 0) {
return ""
}
return truncateToByteLimit(insertText, availableBytes)
}
private fun truncateToByteLimit(text: CharSequence, maxBytes: Int): CharSequence {
var byteCount = 0
var charIndex = 0
while (charIndex < text.length) {
val char = text[charIndex]
val charBytes = when {
char.code < 0x80 -> 1
char.code < 0x800 -> 2
char.isHighSurrogate() -> {
if (charIndex + 1 < text.length && text[charIndex + 1].isLowSurrogate()) {
4
} else {
3
}
}
char.isLowSurrogate() -> 3 // Treat orphaned low surrogate as 3 bytes
else -> 3
}
if (byteCount + charBytes > maxBytes) {
break
}
byteCount += charBytes
charIndex++
if (char.isHighSurrogate() && charIndex < text.length && text[charIndex].isLowSurrogate()) {
charIndex++
}
}
return text.subSequence(0, charIndex)
}
}

View File

@@ -0,0 +1,44 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
/**
* A copy of [okio.utf8Size] that works on [CharSequence].
*/
fun CharSequence.utf8Size(): Int {
var result = 0
var i = 0
while (i < this.length) {
val c = this[i].code
if (c < 0x80) {
// A 7-bit character with 1 byte.
result++
i++
} else if (c < 0x800) {
// An 11-bit character with 2 bytes.
result += 2
i++
} else if (c < 0xd800 || c > 0xdfff) {
// A 16-bit character with 3 bytes.
result += 3
i++
} else {
val low = if (i + 1 < this.length) this[i + 1].code else 0
if (c > 0xdbff || low < 0xdc00 || low > 0xdfff) {
// A malformed surrogate, which yields '?'.
result++
i++
} else {
// A 21-bit character with 4 bytes.
result += 4
i += 2
}
}
}
return result
}

View File

@@ -0,0 +1,124 @@
package org.signal.core.util;
import android.os.Build;
import androidx.annotation.NonNull;
import androidx.annotation.RequiresApi;
import java.util.Iterator;
/**
* Iterates over a string treating a surrogate pair and a grapheme cluster a single character.
*/
public final class CharacterIterable implements Iterable<String> {
private final String string;
public CharacterIterable(@NonNull String string) {
this.string = string;
}
@Override
public @NonNull Iterator<String> iterator() {
return new CharacterIterator();
}
private class CharacterIterator implements Iterator<String> {
private static final int UNINITIALIZED = -2;
private final BreakIteratorCompat breakIterator;
private int lastIndex = UNINITIALIZED;
CharacterIterator() {
this.breakIterator = Build.VERSION.SDK_INT >= 24 ? new AndroidIcuBreakIterator(string)
: new FallbackBreakIterator(string);
}
@Override
public boolean hasNext() {
if (lastIndex == UNINITIALIZED) {
lastIndex = breakIterator.first();
}
return !breakIterator.isDone(lastIndex);
}
@Override
public String next() {
int firstIndex = lastIndex;
lastIndex = breakIterator.next();
return string.substring(firstIndex, lastIndex);
}
}
private interface BreakIteratorCompat {
int first();
int next();
boolean isDone(int index);
}
/**
* An BreakIteratorCompat implementation that delegates calls to `android.icu.text.BreakIterator`.
* This class handles grapheme clusters fine but requires Android API >= 24.
*/
@RequiresApi(24)
private static class AndroidIcuBreakIterator implements BreakIteratorCompat {
private final android.icu.text.BreakIterator breakIterator = android.icu.text.BreakIterator.getCharacterInstance();
public AndroidIcuBreakIterator(@NonNull String string) {
breakIterator.setText(string);
}
@Override
public int first() {
return breakIterator.first();
}
@Override
public int next() {
return breakIterator.next();
}
@Override
public boolean isDone(int index) {
return index == android.icu.text.BreakIterator.DONE;
}
}
/**
* An BreakIteratorCompat implementation that delegates calls to `java.text.BreakIterator`.
* This class may or may not handle grapheme clusters well depending on the underlying implementation.
* In the emulator, API 23 implements ICU version of the BreakIterator so that it handles grapheme
* clusters fine. But API 21 implements RuleBasedIterator which does not handle grapheme clusters.
* <p>
* If it doesn't handle grapheme clusters correctly, in most cases the combined characters are
* broken up into pieces when the code tries to trim a string. For example, an emoji that is
* a combination of a person, gender and skin tone, trimming the character using this class may result
* in trimming the parts of the character, e.g. a dark skin frowning woman emoji may result in
* a neutral skin frowning woman emoji.
*/
private static class FallbackBreakIterator implements BreakIteratorCompat {
private final java.text.BreakIterator breakIterator = java.text.BreakIterator.getCharacterInstance();
public FallbackBreakIterator(@NonNull String string) {
breakIterator.setText(string);
}
@Override
public int first() {
return breakIterator.first();
}
@Override
public int next() {
return breakIterator.next();
}
@Override
public boolean isDone(int index) {
return index == java.text.BreakIterator.DONE;
}
}
}

View File

@@ -0,0 +1,31 @@
package org.signal.core.util
import java.util.Collections
/**
* Flattens a List of Map<K, V> into a Map<K, V> using the + operator.
*
* @return A Map containing all of the K, V pairings of the maps contained in the original list.
*/
fun <K, V> List<Map<K, V>>.flatten(): Map<K, V> = foldRight(emptyMap()) { a, b -> a + b }
/**
* Swaps the elements at the specified positions and returns the result in a new immutable list.
*
* @param i the index of one element to be swapped.
* @param j the index of the other element to be swapped.
*
* @throws IndexOutOfBoundsException if either i or j is out of range.
*/
fun <E> List<E>.swap(i: Int, j: Int): List<E> {
val mutableCopy = this.toMutableList()
Collections.swap(mutableCopy, i, j)
return mutableCopy.toList()
}
/**
* Returns the item wrapped in a list, or an empty list of the item is null.
*/
fun <E> E?.asList(): List<E> {
return if (this == null) emptyList() else listOf(this)
}

View File

@@ -0,0 +1,22 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import android.content.ContentResolver
import android.net.Uri
import android.provider.OpenableColumns
import okio.IOException
@Throws(IOException::class)
fun ContentResolver.getLength(uri: Uri): Long? {
return this.query(uri, arrayOf(OpenableColumns.SIZE), null, null, null)?.use { cursor ->
if (cursor.moveToFirst()) {
cursor.requireLongOrNull(OpenableColumns.SIZE)
} else {
null
}
} ?: openInputStream(uri)?.use { it.readLength() }
}

View File

@@ -0,0 +1,13 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import android.app.DownloadManager
import android.content.Context
fun Context.getDownloadManager(): DownloadManager {
return this.getSystemService(Context.DOWNLOAD_SERVICE) as DownloadManager
}

View File

@@ -0,0 +1,187 @@
/**
* Copyright (C) 2014 Open Whisper Systems
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.signal.core.util;
public class Conversions {
public static byte intsToByteHighAndLow(int highValue, int lowValue) {
return (byte)((highValue << 4 | lowValue) & 0xFF);
}
public static int highBitsToInt(byte value) {
return (value & 0xFF) >> 4;
}
public static int lowBitsToInt(byte value) {
return (value & 0xF);
}
public static int highBitsToMedium(int value) {
return (value >> 12);
}
public static int lowBitsToMedium(int value) {
return (value & 0xFFF);
}
public static byte[] shortToByteArray(int value) {
byte[] bytes = new byte[2];
shortToByteArray(bytes, 0, value);
return bytes;
}
public static int shortToByteArray(byte[] bytes, int offset, int value) {
bytes[offset+1] = (byte)value;
bytes[offset] = (byte)(value >> 8);
return 2;
}
public static int shortToLittleEndianByteArray(byte[] bytes, int offset, int value) {
bytes[offset] = (byte)value;
bytes[offset+1] = (byte)(value >> 8);
return 2;
}
public static byte[] mediumToByteArray(int value) {
byte[] bytes = new byte[3];
mediumToByteArray(bytes, 0, value);
return bytes;
}
public static int mediumToByteArray(byte[] bytes, int offset, int value) {
bytes[offset + 2] = (byte)value;
bytes[offset + 1] = (byte)(value >> 8);
bytes[offset] = (byte)(value >> 16);
return 3;
}
public static byte[] intToByteArray(int value) {
byte[] bytes = new byte[4];
intToByteArray(bytes, 0, value);
return bytes;
}
public static int intToByteArray(byte[] bytes, int offset, int value) {
bytes[offset + 3] = (byte)value;
bytes[offset + 2] = (byte)(value >> 8);
bytes[offset + 1] = (byte)(value >> 16);
bytes[offset] = (byte)(value >> 24);
return 4;
}
public static int intToLittleEndianByteArray(byte[] bytes, int offset, int value) {
bytes[offset] = (byte)value;
bytes[offset+1] = (byte)(value >> 8);
bytes[offset+2] = (byte)(value >> 16);
bytes[offset+3] = (byte)(value >> 24);
return 4;
}
public static byte[] longToByteArray(long l) {
byte[] bytes = new byte[8];
longToByteArray(bytes, 0, l);
return bytes;
}
public static int longToByteArray(byte[] bytes, int offset, long value) {
bytes[offset + 7] = (byte)value;
bytes[offset + 6] = (byte)(value >> 8);
bytes[offset + 5] = (byte)(value >> 16);
bytes[offset + 4] = (byte)(value >> 24);
bytes[offset + 3] = (byte)(value >> 32);
bytes[offset + 2] = (byte)(value >> 40);
bytes[offset + 1] = (byte)(value >> 48);
bytes[offset] = (byte)(value >> 56);
return 8;
}
public static int longTo4ByteArray(byte[] bytes, int offset, long value) {
bytes[offset + 3] = (byte)value;
bytes[offset + 2] = (byte)(value >> 8);
bytes[offset + 1] = (byte)(value >> 16);
bytes[offset + 0] = (byte)(value >> 24);
return 4;
}
public static int byteArrayToShort(byte[] bytes) {
return byteArrayToShort(bytes, 0);
}
public static int byteArrayToShort(byte[] bytes, int offset) {
return
(bytes[offset] & 0xff) << 8 | (bytes[offset + 1] & 0xff);
}
// The SSL patented 3-byte Value.
public static int byteArrayToMedium(byte[] bytes, int offset) {
return
(bytes[offset] & 0xff) << 16 |
(bytes[offset + 1] & 0xff) << 8 |
(bytes[offset + 2] & 0xff);
}
public static int byteArrayToInt(byte[] bytes) {
return byteArrayToInt(bytes, 0);
}
public static int byteArrayToInt(byte[] bytes, int offset) {
return
(bytes[offset] & 0xff) << 24 |
(bytes[offset + 1] & 0xff) << 16 |
(bytes[offset + 2] & 0xff) << 8 |
(bytes[offset + 3] & 0xff);
}
public static int byteArrayToIntLittleEndian(byte[] bytes, int offset) {
return
(bytes[offset + 3] & 0xff) << 24 |
(bytes[offset + 2] & 0xff) << 16 |
(bytes[offset + 1] & 0xff) << 8 |
(bytes[offset] & 0xff);
}
public static long byteArrayToLong(byte[] bytes) {
return byteArrayToLong(bytes, 0);
}
public static long byteArray4ToLong(byte[] bytes, int offset) {
return
((bytes[offset + 0] & 0xffL) << 24) |
((bytes[offset + 1] & 0xffL) << 16) |
((bytes[offset + 2] & 0xffL) << 8) |
((bytes[offset + 3] & 0xffL));
}
public static long byteArrayToLong(byte[] bytes, int offset) {
return
((bytes[offset] & 0xffL) << 56) |
((bytes[offset + 1] & 0xffL) << 48) |
((bytes[offset + 2] & 0xffL) << 40) |
((bytes[offset + 3] & 0xffL) << 32) |
((bytes[offset + 4] & 0xffL) << 24) |
((bytes[offset + 5] & 0xffL) << 16) |
((bytes[offset + 6] & 0xffL) << 8) |
((bytes[offset + 7] & 0xffL));
}
public static int toIntExact(long value) {
if ((int)value != value) {
throw new ArithmeticException("integer overflow");
}
return (int)value;
}
}

View File

@@ -0,0 +1,273 @@
package org.signal.core.util
import android.database.Cursor
import androidx.core.database.getIntOrNull
import androidx.core.database.getLongOrNull
import androidx.core.database.getStringOrNull
import java.util.Optional
fun Cursor.requireString(column: String): String? {
return CursorUtil.requireString(this, column)
}
fun Cursor.requireNonNullString(column: String): String {
return CursorUtil.requireString(this, column)!!
}
fun Cursor.optionalString(column: String): Optional<String> {
return CursorUtil.getString(this, column)
}
fun Cursor.requireInt(column: String): Int {
return CursorUtil.requireInt(this, column)
}
fun Cursor.requireIntOrNull(column: String): Int? {
return this.getIntOrNull(this.getColumnIndexOrThrow(column))
}
fun Cursor.optionalInt(column: String): Optional<Int> {
return CursorUtil.getInt(this, column)
}
fun Cursor.requireFloat(column: String): Float {
return CursorUtil.requireFloat(this, column)
}
fun Cursor.requireLong(column: String): Long {
return CursorUtil.requireLong(this, column)
}
fun Cursor.requireLongOrNull(column: String): Long? {
return this.getLongOrNull(this.getColumnIndexOrThrow(column))
}
fun Cursor.optionalLong(column: String): Optional<Long> {
return CursorUtil.getLong(this, column)
}
fun Cursor.requireBoolean(column: String): Boolean {
return CursorUtil.requireInt(this, column) != 0
}
fun Cursor.optionalBoolean(column: String): Optional<Boolean> {
return CursorUtil.getBoolean(this, column)
}
fun Cursor.requireBlob(column: String): ByteArray? {
return CursorUtil.requireBlob(this, column)
}
fun Cursor.requireNonNullBlob(column: String): ByteArray {
return CursorUtil.requireBlob(this, column)!!
}
fun Cursor.optionalBlob(column: String): Optional<ByteArray> {
return CursorUtil.getBlob(this, column)
}
fun Cursor.isNull(column: String): Boolean {
return CursorUtil.isNull(this, column)
}
fun <T> Cursor.requireObject(column: String, serializer: LongSerializer<T>): T {
return serializer.deserialize(CursorUtil.requireLong(this, column))
}
fun <T> Cursor.requireObject(column: String, serializer: StringSerializer<T>): T {
return serializer.deserialize(CursorUtil.requireString(this, column))
}
fun <T> Cursor.requireObject(column: String, serializer: IntSerializer<T>): T {
return serializer.deserialize(CursorUtil.requireInt(this, column))
}
@JvmOverloads
fun Cursor.readToSingleLong(defaultValue: Long = 0): Long {
return readToSingleLongOrNull() ?: defaultValue
}
fun Cursor.readToSingleLongOrNull(): Long? {
return use {
if (it.moveToFirst()) {
it.getLongOrNull(0)
} else {
null
}
}
}
fun <T> Cursor.readToSingleObject(serializer: BaseSerializer<T, Cursor, *>): T? {
return use {
if (it.moveToFirst()) {
serializer.deserialize(it)
} else {
null
}
}
}
fun <T> Cursor.readToSingleObject(mapper: (Cursor) -> T): T? {
return use {
if (it.moveToFirst()) {
mapper(it)
} else {
null
}
}
}
@JvmOverloads
fun Cursor.readToSingleInt(defaultValue: Int = 0): Int {
return use {
if (it.moveToFirst()) {
it.getInt(0)
} else {
defaultValue
}
}
}
fun Cursor.readToSingleIntOrNull(): Int? {
return use {
if (it.moveToFirst()) {
it.getIntOrNull(0)
} else {
null
}
}
}
fun Cursor.readToSingleBoolean(defaultValue: Boolean = false): Boolean {
return use {
if (it.moveToFirst()) {
it.getInt(0) != 0
} else {
defaultValue
}
}
}
@JvmOverloads
inline fun <T> Cursor.readToList(predicate: (T) -> Boolean = { true }, mapper: (Cursor) -> T): List<T> {
val list = mutableListOf<T>()
use {
while (moveToNext()) {
val record = mapper(this)
if (predicate(record)) {
list += mapper(this)
}
}
}
return list
}
@JvmOverloads
inline fun <K, V> Cursor.readToMap(predicate: (Pair<K, V>) -> Boolean = { true }, mapper: (Cursor) -> Pair<K, V>): Map<K, V> {
return readToList(predicate, mapper).associate { it }
}
/**
* Groups the cursor by the given key, and returns a map of keys to lists of values.
*/
inline fun <K, V> Cursor.groupBy(mapper: (Cursor) -> Pair<K, V>): Map<K, List<V>> {
val map: MutableMap<K, MutableList<V>> = mutableMapOf()
use {
while (moveToNext()) {
val pair = mapper(this)
val list = map.getOrPut(pair.first) { mutableListOf() }
list += pair.second
}
}
return map
}
inline fun <T> Cursor.readToSet(predicate: (T) -> Boolean = { true }, mapper: (Cursor) -> T): Set<T> {
val set = mutableSetOf<T>()
use {
while (moveToNext()) {
val record = mapper(this)
if (predicate(record)) {
set += mapper(this)
}
}
}
return set
}
inline fun <T> Cursor.firstOrNull(predicate: (T) -> Boolean = { true }, mapper: (Cursor) -> T): T? {
use {
while (moveToNext()) {
val record = mapper(this)
if (predicate(record)) {
return record
}
}
}
return null
}
inline fun Cursor.forEach(operation: (Cursor) -> Unit) {
use {
while (moveToNext()) {
operation(this)
}
}
}
inline fun Cursor.forEachIndexed(operation: (Int, Cursor) -> Unit) {
use {
var i = 0
while (moveToNext()) {
operation(i++, this)
}
}
}
fun Cursor.iterable(): Iterable<Cursor> {
return CursorIterable(this)
}
fun Boolean.toInt(): Int = if (this) 1 else 0
/**
* Renders the entire cursor row as a string.
* Not necessarily used in the app, but very useful to have available when debugging.
*/
fun Cursor.rowToString(): String {
val builder = StringBuilder()
for (i in 0 until this.columnCount) {
builder
.append(this.getColumnName(i))
.append("=")
.append(this.getStringOrNull(i))
if (i < this.columnCount - 1) {
builder.append(", ")
}
}
return builder.toString()
}
private class CursorIterable(private val cursor: Cursor) : Iterable<Cursor> {
override fun iterator(): Iterator<Cursor> {
return CursorIterator(cursor)
}
}
private class CursorIterator(private val cursor: Cursor) : Iterator<Cursor> {
override fun hasNext(): Boolean {
return !cursor.isClosed && cursor.count > 0 && !cursor.isLast && !cursor.isAfterLast
}
override fun next(): Cursor {
return if (cursor.moveToNext()) {
cursor
} else {
throw NoSuchElementException()
}
}
}

View File

@@ -0,0 +1,107 @@
package org.signal.core.util;
import android.database.Cursor;
import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import java.util.Optional;
import java.util.function.Function;
public final class CursorUtil {
private CursorUtil() {}
public static String requireString(@NonNull Cursor cursor, @NonNull String column) {
return cursor.getString(cursor.getColumnIndexOrThrow(column));
}
public static int requireInt(@NonNull Cursor cursor, @NonNull String column) {
return cursor.getInt(cursor.getColumnIndexOrThrow(column));
}
public static float requireFloat(@NonNull Cursor cursor, @NonNull String column) {
return cursor.getFloat(cursor.getColumnIndexOrThrow(column));
}
public static long requireLong(@NonNull Cursor cursor, @NonNull String column) {
return cursor.getLong(cursor.getColumnIndexOrThrow(column));
}
public static boolean requireBoolean(@NonNull Cursor cursor, @NonNull String column) {
return requireInt(cursor, column) != 0;
}
public static byte[] requireBlob(@NonNull Cursor cursor, @NonNull String column) {
return cursor.getBlob(cursor.getColumnIndexOrThrow(column));
}
public static boolean isNull(@NonNull Cursor cursor, @NonNull String column) {
return cursor.isNull(cursor.getColumnIndexOrThrow(column));
}
public static boolean requireMaskedBoolean(@NonNull Cursor cursor, @NonNull String column, int position) {
return Bitmask.read(requireLong(cursor, column), position);
}
public static int requireMaskedInt(@NonNull Cursor cursor, @NonNull String column, int position, int flagBitSize) {
return Conversions.toIntExact(Bitmask.read(requireLong(cursor, column), position, flagBitSize));
}
public static Optional<String> getString(@NonNull Cursor cursor, @NonNull String column) {
if (cursor.getColumnIndex(column) < 0) {
return Optional.empty();
} else {
return Optional.ofNullable(requireString(cursor, column));
}
}
public static Optional<Integer> getInt(@NonNull Cursor cursor, @NonNull String column) {
if (cursor.getColumnIndex(column) < 0) {
return Optional.empty();
} else {
return Optional.of(requireInt(cursor, column));
}
}
public static Optional<Long> getLong(@NonNull Cursor cursor, @NonNull String column) {
if (cursor.getColumnIndex(column) < 0) {
return Optional.empty();
} else {
return Optional.of(requireLong(cursor, column));
}
}
public static Optional<Boolean> getBoolean(@NonNull Cursor cursor, @NonNull String column) {
if (cursor.getColumnIndex(column) < 0) {
return Optional.empty();
} else {
return Optional.of(requireBoolean(cursor, column));
}
}
public static Optional<byte[]> getBlob(@NonNull Cursor cursor, @NonNull String column) {
if (cursor.getColumnIndex(column) < 0) {
return Optional.empty();
} else {
return Optional.ofNullable(requireBlob(cursor, column));
}
}
/**
* Reads each column as a string, and concatenates them together into a single string separated by |
*/
public static String readRowAsString(@NonNull Cursor cursor) {
StringBuilder row = new StringBuilder();
for (int i = 0, len = cursor.getColumnCount(); i < len; i++) {
row.append(cursor.getString(i));
if (i < len - 1) {
row.append(" | ");
}
}
return row.toString();
}
}

View File

@@ -0,0 +1,7 @@
package org.signal.core.util;
import androidx.annotation.NonNull;
public interface DatabaseId {
@NonNull String serialize();
}

View File

@@ -0,0 +1,73 @@
package org.signal.core.util;
import android.content.res.Resources;
import androidx.annotation.Dimension;
import androidx.annotation.Px;
/**
* Core utility for converting different dimensional values.
*/
public enum DimensionUnit {
PIXELS {
@Override
@Px
public float toPixels(@Px float pixels) {
return pixels;
}
@Override
@Dimension(unit = Dimension.DP)
public float toDp(@Px float pixels) {
return pixels / Resources.getSystem().getDisplayMetrics().density;
}
@Override
@Dimension(unit = Dimension.SP)
public float toSp(@Px float pixels) {
return pixels / Resources.getSystem().getDisplayMetrics().scaledDensity;
}
},
DP {
@Override
@Px
public float toPixels(@Dimension(unit = Dimension.DP) float dp) {
return dp * Resources.getSystem().getDisplayMetrics().density;
}
@Override
@Dimension(unit = Dimension.DP)
public float toDp(@Dimension(unit = Dimension.DP) float dp) {
return dp;
}
@Override
@Dimension(unit = Dimension.SP)
public float toSp(@Dimension(unit = Dimension.DP) float dp) {
return PIXELS.toSp(toPixels(dp));
}
},
SP {
@Override
@Px
public float toPixels(@Dimension(unit = Dimension.SP) float sp) {
return sp * Resources.getSystem().getDisplayMetrics().scaledDensity;
}
@Override
@Dimension(unit = Dimension.DP)
public float toDp(@Dimension(unit = Dimension.SP) float sp) {
return PIXELS.toDp(toPixels(sp));
}
@Override
@Dimension(unit = Dimension.SP)
public float toSp(@Dimension(unit = Dimension.SP) float sp) {
return sp;
}
};
public abstract float toPixels(float value);
public abstract float toDp(float value);
public abstract float toSp(float value);
}

View File

@@ -0,0 +1,27 @@
package org.signal.core.util
import androidx.annotation.Px
/**
* Converts the given Float DP value into Pixels.
*/
@get:Px
val Float.dp: Float get() = DimensionUnit.DP.toPixels(this)
/**
* Converts the given Int DP value into Pixels
*/
@get:Px
val Int.dp: Int get() = this.toFloat().dp.toInt()
/**
* Converts the given Float SP value into Pixels.
*/
@get:Px
val Float.sp: Float get() = DimensionUnit.SP.toPixels(this)
/**
* Converts the given Int SP value into Pixels
*/
@get:Px
val Int.sp: Int get() = this.toFloat().sp.toInt()

View File

@@ -0,0 +1,85 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import android.app.usage.StorageStatsManager
import android.content.Context
import android.os.Build
import android.os.StatFs
import android.os.storage.StorageManager
import androidx.annotation.RequiresApi
import org.signal.core.util.logging.Log
object DiskUtil {
private val TAG = Log.tag(DiskUtil::class)
/**
* Gets the remaining storage usable by the application.
*
* @param context The application context
*/
@JvmStatic
fun getAvailableSpace(context: Context): ByteSize {
return if (Build.VERSION.SDK_INT >= 26) {
getAvailableStorageBytesApi26(context).bytes
} else {
return getAvailableStorageBytesLegacy(context).bytes
}
}
/**
* Gets the total disk size of the volume used by the application.
*
* @param context The application context
*/
@JvmStatic
fun getTotalDiskSize(context: Context): ByteSize {
return if (Build.VERSION.SDK_INT >= 26) {
getTotalDiskSizeApi26(context).bytes
} else {
return getTotalDiskSizeLegacy(context).bytes
}
}
@RequiresApi(26)
private fun getAvailableStorageBytesApi26(context: Context): Long {
val storageManager = context.getSystemService(Context.STORAGE_SERVICE) as StorageManager
val storageStatsManager = context.getSystemService(Context.STORAGE_STATS_SERVICE) as StorageStatsManager
val appStorageUuid = storageManager.getUuidForPath(context.filesDir)
return try {
storageStatsManager.getFreeBytes(appStorageUuid)
} catch (e: Throwable) {
Log.w(TAG, "Hit a weird platform bug! Falling back to legacy.", e)
getAvailableStorageBytesLegacy(context)
}
}
private fun getAvailableStorageBytesLegacy(context: Context): Long {
val stat = StatFs(context.filesDir.absolutePath)
return stat.availableBytes
}
@RequiresApi(26)
private fun getTotalDiskSizeApi26(context: Context): Long {
val storageManager = context.getSystemService(Context.STORAGE_SERVICE) as StorageManager
val storageStatsManager = context.getSystemService(Context.STORAGE_STATS_SERVICE) as StorageStatsManager
val appStorageUuid = storageManager.getUuidForPath(context.filesDir)
return try {
storageStatsManager.getTotalBytes(appStorageUuid)
} catch (e: Throwable) {
Log.w(TAG, "Hit a weird platform bug! Falling back to legacy.", e)
getTotalDiskSizeLegacy(context)
}
}
private fun getTotalDiskSizeLegacy(context: Context): Long {
val stat = StatFs(context.filesDir.absolutePath)
return stat.totalBytes
}
}

View File

@@ -0,0 +1,78 @@
package org.signal.core.util;
import android.annotation.SuppressLint;
import android.annotation.TargetApi;
import android.graphics.PorterDuff;
import android.graphics.PorterDuffColorFilter;
import android.graphics.drawable.Drawable;
import android.os.Build;
import android.text.InputFilter;
import android.widget.EditText;
import android.widget.TextView;
import androidx.annotation.ColorInt;
import androidx.annotation.NonNull;
import androidx.annotation.RequiresApi;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public final class EditTextUtil {
private EditTextUtil() {
}
public static void addGraphemeClusterLimitFilter(EditText text, int maximumGraphemes) {
List<InputFilter> filters = new ArrayList<>(Arrays.asList(text.getFilters()));
filters.add(new GraphemeClusterLimitFilter(maximumGraphemes));
text.setFilters(filters.toArray(new InputFilter[0]));
}
public static void setCursorColor(@NonNull EditText text, @ColorInt int colorInt) {
if (Build.VERSION.SDK_INT >= 29) {
Drawable drawable = text.getTextCursorDrawable();
if (drawable == null) {
return;
}
Drawable cursorDrawable = drawable.mutate();
cursorDrawable.setColorFilter(new PorterDuffColorFilter(colorInt, PorterDuff.Mode.SRC_IN));
text.setTextCursorDrawable(cursorDrawable);
} else {
setCursorColorViaReflection(text, colorInt);
}
}
/**
* Note: This is only ever called in API 28 and less.
*/
@SuppressLint("SoonBlockedPrivateApi")
private static void setCursorColorViaReflection(EditText editText, int color) {
try {
Field fCursorDrawableRes = TextView.class.getDeclaredField("mCursorDrawableRes");
fCursorDrawableRes.setAccessible(true);
int mCursorDrawableRes = fCursorDrawableRes.getInt(editText);
Field fEditor = TextView.class.getDeclaredField("mEditor");
fEditor.setAccessible(true);
Object editor = fEditor.get(editText);
Class<?> clazz = editor.getClass();
Field fCursorDrawable = clazz.getDeclaredField("mCursorDrawable");
fCursorDrawable.setAccessible(true);
Drawable[] drawables = new Drawable[2];
drawables[0] = editText.getContext().getResources().getDrawable(mCursorDrawableRes);
drawables[1] = editText.getContext().getResources().getDrawable(mCursorDrawableRes);
drawables[0].setColorFilter(color, PorterDuff.Mode.SRC_IN);
drawables[1].setColorFilter(color, PorterDuff.Mode.SRC_IN);
fCursorDrawable.set(editor, drawables);
} catch (Throwable ignored) {
}
}
}

View File

@@ -0,0 +1,114 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import kotlin.math.ceil
import kotlin.math.floor
import kotlin.time.Duration.Companion.nanoseconds
import kotlin.time.DurationUnit
/**
* Used to track performance metrics for large clusters of similar events.
* For instance, if you were doing a backup restore and had to important many different kinds of data in an unknown order, you could
* use this to learn stats around how long each kind of data takes to import.
*
* It is assumed that all events are happening serially with no delays in between.
*
* The timer tracks things at nanosecond granularity, but presents data as fractional milliseconds for readability.
*/
class EventTimer {
private val durationsByGroup: MutableMap<String, MutableList<Long>> = mutableMapOf()
private var startTime = System.nanoTime()
private var lastTimeNanos: Long = startTime
fun reset() {
startTime = System.nanoTime()
lastTimeNanos = startTime
durationsByGroup.clear()
}
/**
* Indicates an event in the specified group has finished.
*/
fun emit(group: String) {
val now = System.nanoTime()
val duration = now - lastTimeNanos
durationsByGroup.getOrPut(group) { mutableListOf() } += duration
lastTimeNanos = now
}
/**
* Stops the timer and returns a mapping of group -> [EventMetrics], which will tell you various statistics around timings for that group.
*/
fun stop(): EventTimerResults {
val data: Map<String, EventMetrics> = durationsByGroup
.mapValues { entry ->
val sorted: List<Long> = entry.value.sorted()
EventMetrics(
totalTime = sorted.sum().nanoseconds.toDouble(DurationUnit.MILLISECONDS),
eventCount = sorted.size,
sortedDurationNanos = sorted
)
}
return EventTimerResults(data)
}
class EventTimerResults(data: Map<String, EventMetrics>) : Map<String, EventMetrics> by data {
val summary by lazy {
val builder = StringBuilder()
builder.append("[overall] totalTime: ${data.values.map { it.totalTime }.sum().roundedString(2)} ")
for (entry in data) {
builder.append("[${entry.key}] totalTime: ${entry.value.totalTime.roundedString(2)}, count: ${entry.value.eventCount}, p50: ${entry.value.p(50)}, p90: ${entry.value.p(90)}, p99: ${entry.value.p(99)} ")
}
builder.toString()
}
}
data class EventMetrics(
/** The sum of all event durations, in fractional milliseconds. */
val totalTime: Double,
/** Total number of events observed. */
val eventCount: Int,
private val sortedDurationNanos: List<Long>
) {
/**
* Returns the percentile of the duration data (e.g. p50, p90) as a formatted string containing fractional milliseconds rounded to the requested number of decimal places.
*/
fun p(percentile: Int, decimalPlaces: Int = 2): String {
return pNanos(percentile).nanoseconds.toDouble(DurationUnit.MILLISECONDS).roundedString(decimalPlaces)
}
private fun pNanos(percentile: Int): Long {
if (sortedDurationNanos.isEmpty()) {
return 0L
}
val index: Float = (percentile / 100f) * (sortedDurationNanos.size - 1)
val lowerIndex: Int = floor(index).toInt()
val upperIndex: Int = ceil(index).toInt()
if (lowerIndex == upperIndex) {
return sortedDurationNanos[lowerIndex]
}
val interpolationFactor: Float = index - lowerIndex
val lowerValue: Long = sortedDurationNanos[lowerIndex]
val upperValue: Long = sortedDurationNanos[upperIndex]
return floor(lowerValue + (upperValue - lowerValue) * interpolationFactor).toLong()
}
}
}

View File

@@ -0,0 +1,96 @@
package org.signal.core.util;
import androidx.annotation.NonNull;
import org.signal.core.util.logging.Scrubber;
import java.io.ByteArrayOutputStream;
import java.io.PrintStream;
public final class ExceptionUtil {
private ExceptionUtil() {}
/**
* Joins the stack trace of the inferred call site with the original exception. This is
* useful for when exceptions are thrown inside of asynchronous systems (like runnables in an
* executor) where you'd otherwise lose important parts of the stack trace. This lets you save a
* throwable at the entry point, and then combine it with any caught exceptions later.
*
* The resulting stack trace will look like this:
*
* Inferred
* Stack
* Trace
* [[ ↑↑ Inferred Trace ↑↑ ]]
* [[ ↓↓ Original Trace ↓↓ ]]
* Original
* Stack
* Trace
*
* @return The provided original exception, for convenience.
*/
public static <E extends Throwable> E joinStackTrace(@NonNull E original, @NonNull Throwable inferred) {
StackTraceElement[] combinedTrace = joinStackTrace(original.getStackTrace(), inferred.getStackTrace());
original.setStackTrace(combinedTrace);
return original;
}
/**
* See {@link #joinStackTrace(Throwable, Throwable)}
*/
public static StackTraceElement[] joinStackTrace(@NonNull StackTraceElement[] originalTrace, @NonNull StackTraceElement[] inferredTrace) {
StackTraceElement[] combinedTrace = new StackTraceElement[originalTrace.length + inferredTrace.length + 2];
System.arraycopy(originalTrace, 0, combinedTrace, 0, originalTrace.length);
combinedTrace[originalTrace.length] = new StackTraceElement("[[ ↑↑ Original Trace ↑↑ ]]", "", "", 0);
combinedTrace[originalTrace.length + 1] = new StackTraceElement("[[ ↓↓ Inferred Trace ↓↓ ]]", "", "", 0);
System.arraycopy(inferredTrace, 0, combinedTrace, originalTrace.length + 2, inferredTrace.length);
return combinedTrace;
}
/**
* Joins the stack trace with the exception's {@link Throwable#getMessage()}.
*
* The resulting stack trace will look like this:
*
* Original
* Stack
* Trace
* [[ ↑↑ Original Trace ↑↑ ]]
* [[ ↓↓ Exception Message ↓↓ ]]
* Exception Message
*
* @return The provided original exception, for convenience.
*/
public static @NonNull <E extends Throwable> E joinStackTraceAndMessage(@NonNull E original) {
StackTraceElement[] originalTrace = original.getStackTrace();
StackTraceElement[] combinedTrace = new StackTraceElement[originalTrace.length + 3];
System.arraycopy(originalTrace, 0, combinedTrace, 0, originalTrace.length);
String message = Scrubber.scrub(original.getMessage() != null ? original.getMessage() : "null").toString();
if (message.startsWith("Context.startForegroundService")) {
try {
String service = message.substring(message.lastIndexOf('.') + 1, message.length() - 1);
message = service + " did not call startForeground";
} catch (Exception ignored) {}
}
combinedTrace[originalTrace.length] = new StackTraceElement("[[ ↑↑ Original Trace ↑↑ ]]", "", "", 0);
combinedTrace[originalTrace.length + 1] = new StackTraceElement("[[ ↓↓ Exception Message ↓↓ ]]", "", "", 0);
combinedTrace[originalTrace.length + 2] = new StackTraceElement(message, "", "", 0);
original.setStackTrace(combinedTrace);
return original;
}
public static @NonNull String convertThrowableToString(@NonNull Throwable throwable) {
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
throwable.printStackTrace(new PrintStream(outputStream));
return outputStream.toString();
}
}

View File

@@ -0,0 +1,39 @@
package org.signal.core.util
import android.graphics.Bitmap
import android.graphics.Canvas
import android.graphics.Color
import android.graphics.Paint
import android.graphics.PorterDuff
import kotlin.math.abs
object FontUtil {
private const val SAMPLE_EMOJI = "\uD83C\uDF0D" // 🌍
/**
* Certain platforms cannot render emoji above a certain font size.
*
* This will attempt to render an emoji at the specified font size and tell you if it's possible.
* It does this by rendering an emoji into a 1x1 bitmap and seeing if the resulting pixel is non-transparent.
*
* https://stackoverflow.com/a/50988748
*/
@JvmStatic
fun canRenderEmojiAtFontSize(size: Float): Boolean {
val bitmap: Bitmap = Bitmap.createBitmap(1, 1, Bitmap.Config.ARGB_8888)
val canvas = Canvas(bitmap)
val paint = Paint()
paint.textSize = size
paint.textAlign = Paint.Align.CENTER
val ascent: Float = abs(paint.ascent())
val descent: Float = abs(paint.descent())
val halfHeight = (ascent + descent) / 2.0f
canvas.drawColor(Color.TRANSPARENT, PorterDuff.Mode.CLEAR)
canvas.drawText(SAMPLE_EMOJI, 0.5f, 0.5f + halfHeight - descent, paint)
return bitmap.getPixel(0, 0) != 0
}
}

View File

@@ -0,0 +1,54 @@
package org.signal.core.util;
import android.text.InputFilter;
import android.text.Spanned;
import org.signal.core.util.logging.Log;
/**
* This filter will constrain edits not to make the number of character breaks of the text
* greater than the specified maximum.
* <p>
* This means it will limit to a maximum number of grapheme clusters.
*/
public final class GraphemeClusterLimitFilter implements InputFilter {
private static final String TAG = Log.tag(GraphemeClusterLimitFilter.class);
private final BreakIteratorCompat breakIteratorCompat;
private final int max;
public GraphemeClusterLimitFilter(int max) {
this.breakIteratorCompat = BreakIteratorCompat.getInstance();
this.max = max;
}
@Override
public CharSequence filter(CharSequence source, int start, int end, Spanned dest, int dstart, int dend) {
CharSequence sourceFragment = source.subSequence(start, end);
CharSequence head = dest.subSequence(0, dstart);
CharSequence tail = dest.subSequence(dend, dest.length());
breakIteratorCompat.setText(String.format("%s%s%s", head, sourceFragment, tail));
int length = breakIteratorCompat.countBreaks();
if (length > max) {
breakIteratorCompat.setText(sourceFragment);
int sourceLength = breakIteratorCompat.countBreaks();
CharSequence trimmedSource = breakIteratorCompat.take(sourceLength - (length - max));
breakIteratorCompat.setText(String.format("%s%s%s", head, trimmedSource, tail));
int newExpectedCount = breakIteratorCompat.countBreaks();
if (newExpectedCount > max) {
Log.w(TAG, "Failed to create string under the required length " + newExpectedCount);
return "";
}
return trimmedSource;
}
return source;
}
}

View File

@@ -0,0 +1,23 @@
package org.signal.core.util
import android.content.Intent
import android.os.Build
import android.os.Parcelable
fun <T : Parcelable> Intent.getParcelableExtraCompat(key: String, clazz: Class<T>): T? {
return if (Build.VERSION.SDK_INT >= 33) {
this.getParcelableExtra(key, clazz)
} else {
@Suppress("DEPRECATION")
this.getParcelableExtra(key)
}
}
fun <T : Parcelable> Intent.getParcelableArrayListExtraCompat(key: String, clazz: Class<T>): ArrayList<T>? {
return if (Build.VERSION.SDK_INT >= 33) {
this.getParcelableArrayListExtra(key, clazz)
} else {
@Suppress("DEPRECATION")
this.getParcelableArrayListExtra(key)
}
}

View File

@@ -0,0 +1,23 @@
package org.signal.core.util;
import java.util.concurrent.LinkedBlockingDeque;
public class LinkedBlockingLifoQueue<E> extends LinkedBlockingDeque<E> {
@Override
public void put(E runnable) throws InterruptedException {
super.putFirst(runnable);
}
@Override
public boolean add(E runnable) {
super.addFirst(runnable);
return true;
}
@Override
public boolean offer(E runnable) {
super.addFirst(runnable);
return true;
}
}

View File

@@ -0,0 +1,39 @@
package org.signal.core.util;
import androidx.annotation.NonNull;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.stream.Stream;
public final class ListUtil {
private ListUtil() {}
public static <E> List<List<E>> chunk(@NonNull List<E> list, int chunkSize) {
List<List<E>> chunks = new ArrayList<>(list.size() / chunkSize);
for (int i = 0; i < list.size(); i += chunkSize) {
List<E> chunk = list.subList(i, Math.min(list.size(), i + chunkSize));
chunks.add(chunk);
}
return chunks;
}
@SafeVarargs
public static <T> List<T> concat(Collection<T>... items) {
final List<T> concat = new ArrayList<>(Stream.of(items).map(Collection::size).reduce(0, Integer::sum));
for (Collection<T> list : items) {
concat.addAll(list);
}
return concat;
}
public static <T> List<T> emptyIfNull(List<T> list) {
return list == null ? Collections.emptyList() : list;
}
}

View File

@@ -0,0 +1,30 @@
package org.signal.core.util;
import android.os.Build;
import androidx.annotation.NonNull;
import java.util.Map;
import java.util.function.Function;
public final class MapUtil {
private MapUtil() {}
@NonNull
public static <K, V> V getOrDefault(@NonNull Map<K, V> map, @NonNull K key, @NonNull V defaultValue) {
if (Build.VERSION.SDK_INT >= 24) {
//noinspection ConstantConditions
return map.getOrDefault(key, defaultValue);
} else {
V v = map.get(key);
return v == null ? defaultValue : v;
}
}
@NonNull
public static <K, V, M> M mapOrDefault(@NonNull Map<K, V> map, @NonNull K key, @NonNull Function<V, M> mapper, @NonNull M defaultValue) {
V v = map.get(key);
return v == null ? defaultValue : mapper.apply(v);
}
}

View File

@@ -0,0 +1,151 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import android.app.ActivityManager
import android.content.Context
import android.os.Debug
import android.os.Handler
import org.signal.core.util.concurrent.SignalExecutors
import org.signal.core.util.logging.Log
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds
object MemoryTracker {
private val TAG = Log.tag(MemoryTracker::class.java)
private val runtime: Runtime = Runtime.getRuntime()
private val activityMemoryInfo: ActivityManager.MemoryInfo = ActivityManager.MemoryInfo()
private val debugMemoryInfo: Debug.MemoryInfo = Debug.MemoryInfo()
private val handler: Handler = Handler(SignalExecutors.getAndStartHandlerThread("MemoryTracker", ThreadUtil.PRIORITY_BACKGROUND_THREAD).looper)
private val POLLING_INTERVAL = 5.seconds.inWholeMilliseconds
private var running = false
private lateinit var previousAppHeadUsage: AppHeapUsage
private var increaseMemoryCount = 0
@JvmStatic
fun start() {
Log.d(TAG, "Beginning memory monitoring.")
running = true
previousAppHeadUsage = getAppJvmHeapUsage()
increaseMemoryCount = 0
handler.postDelayed(this::poll, POLLING_INTERVAL)
}
@JvmStatic
fun stop() {
Log.d(TAG, "Ending memory monitoring.")
running = false
handler.removeCallbacksAndMessages(null)
}
fun poll() {
val currentHeapUsage = getAppJvmHeapUsage()
if (currentHeapUsage.currentTotalBytes != previousAppHeadUsage.currentTotalBytes) {
if (currentHeapUsage.currentTotalBytes > previousAppHeadUsage.currentTotalBytes) {
Log.d(TAG, "The system increased our app JVM heap from ${previousAppHeadUsage.currentTotalBytes.byteDisplay()} to ${currentHeapUsage.currentTotalBytes.byteDisplay()}")
} else {
Log.d(TAG, "The system decreased our app JVM heap from ${previousAppHeadUsage.currentTotalBytes.byteDisplay()} to ${currentHeapUsage.currentTotalBytes.byteDisplay()}")
}
}
if (currentHeapUsage.usedBytes >= previousAppHeadUsage.usedBytes) {
increaseMemoryCount++
} else {
Log.d(TAG, "Used memory has decreased from ${previousAppHeadUsage.usedBytes.byteDisplay()} to ${currentHeapUsage.usedBytes.byteDisplay()}")
increaseMemoryCount = 0
}
if (increaseMemoryCount > 0 && increaseMemoryCount % 5 == 0) {
Log.d(TAG, "Used memory has increased or stayed the same for the last $increaseMemoryCount intervals (${increaseMemoryCount * POLLING_INTERVAL.milliseconds.inWholeSeconds} seconds). Using: ${currentHeapUsage.usedBytes.byteDisplay()}, Free: ${currentHeapUsage.freeBytes.byteDisplay()}, CurrentTotal: ${currentHeapUsage.currentTotalBytes.byteDisplay()}, MaxPossible: ${currentHeapUsage.maxPossibleBytes.byteDisplay()}")
}
previousAppHeadUsage = currentHeapUsage
if (running) {
handler.postDelayed(this::poll, POLLING_INTERVAL)
}
}
/**
* Gives us basic memory usage data for our app JVM heap usage. Very fast, ~10 micros on an emulator.
*/
fun getAppJvmHeapUsage(): AppHeapUsage {
return AppHeapUsage(
freeBytes = runtime.freeMemory(),
currentTotalBytes = runtime.totalMemory(),
maxPossibleBytes = runtime.maxMemory()
)
}
/**
* This gives us details stats, but it takes an appreciable amount of time. On an emulator, it can take ~30ms.
* As a result, we don't want to be calling this regularly for most users.
*/
fun getDetailedMemoryStats(): DetailedMemoryStats {
Debug.getMemoryInfo(debugMemoryInfo)
return DetailedMemoryStats(
appJavaHeapUsageKb = debugMemoryInfo.getMemoryStat("summary.java-heap")?.toLongOrNull(),
appNativeHeapUsageKb = debugMemoryInfo.getMemoryStat("summary.native-heap")?.toLongOrNull(),
codeUsageKb = debugMemoryInfo.getMemoryStat("summary.code")?.toLongOrNull(),
stackUsageKb = debugMemoryInfo.getMemoryStat("summary.stack")?.toLongOrNull(),
graphicsUsageKb = debugMemoryInfo.getMemoryStat("summary.graphics")?.toLongOrNull(),
appOtherUsageKb = debugMemoryInfo.getMemoryStat("summary.private-other")?.toLongOrNull()
)
}
fun getSystemNativeMemoryUsage(context: Context): NativeMemoryUsage {
val activityManager: ActivityManager = context.getSystemService(Context.ACTIVITY_SERVICE) as ActivityManager
activityManager.getMemoryInfo(activityMemoryInfo)
return NativeMemoryUsage(
freeBytes = activityMemoryInfo.availMem,
totalBytes = activityMemoryInfo.totalMem,
lowMemory = activityMemoryInfo.lowMemory,
lowMemoryThreshold = activityMemoryInfo.threshold
)
}
private fun Long.byteDisplay(): String {
return "$this (${this.bytes.inMebiBytes.roundedString(2)} MiB)"
}
data class AppHeapUsage(
/** The number of bytes that are free to use. */
val freeBytes: Long,
/** The current total number of bytes our app could use. This can increase over time as the system increases our allocation. */
val currentTotalBytes: Long,
/** The maximum number of bytes that our app could ever be given. */
val maxPossibleBytes: Long
) {
/** The number of bytes that our app is currently using. */
val usedBytes: Long
get() = currentTotalBytes - freeBytes
}
data class NativeMemoryUsage(
val freeBytes: Long,
val totalBytes: Long,
val lowMemory: Boolean,
val lowMemoryThreshold: Long
) {
val usedBytes: Long
get() = totalBytes - freeBytes
}
data class DetailedMemoryStats(
val appJavaHeapUsageKb: Long?,
val appNativeHeapUsageKb: Long?,
val codeUsageKb: Long?,
val graphicsUsageKb: Long?,
val stackUsageKb: Long?,
val appOtherUsageKb: Long?
)
}

View File

@@ -0,0 +1,129 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import java.util.Queue
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ConcurrentLinkedQueue
import kotlin.math.ceil
import kotlin.math.floor
import kotlin.time.Duration.Companion.nanoseconds
import kotlin.time.DurationUnit
/**
* Used to track performance metrics for large clusters of similar events that are happening simultaneously.
*
* Very similar to [EventTimer], but with no assumptions around threading,
*
* The timer tracks things at nanosecond granularity, but presents data as fractional milliseconds for readability.
*/
class ParallelEventTimer {
val durationsByGroup: MutableMap<String, Queue<Long>> = ConcurrentHashMap()
private var startTime = System.nanoTime()
fun reset() {
durationsByGroup.clear()
startTime = System.nanoTime()
}
/**
* Begin an event associated with a group. You must call [EventStopper.stopEvent] on the returned object in order to indicate the action has completed.
*/
fun beginEvent(group: String): EventStopper {
val start = System.nanoTime()
return EventStopper {
val duration = System.nanoTime() - start
durationsByGroup.computeIfAbsent(group) { ConcurrentLinkedQueue() } += duration
}
}
/**
* Time an event associated with a group.
*/
inline fun <E> timeEvent(group: String, operation: () -> E): E {
val start = System.nanoTime()
val result = operation()
val duration = System.nanoTime() - start
durationsByGroup.computeIfAbsent(group) { ConcurrentLinkedQueue() } += duration
return result
}
/**
* Stops the timer and returns a mapping of group -> [EventMetrics], which will tell you various statistics around timings for that group.
* It is assumed that all events have been stopped by the time this has been called.
*/
fun stop(): EventTimerResults {
val totalDuration = System.nanoTime() - startTime
val data: Map<String, EventMetrics> = durationsByGroup
.mapValues { entry ->
val sorted: List<Long> = entry.value.sorted()
EventMetrics(
totalEventTime = sorted.sum().nanoseconds.toDouble(DurationUnit.MILLISECONDS),
eventCount = sorted.size,
sortedDurationNanos = sorted
)
}
return EventTimerResults(totalDuration.nanoseconds.toDouble(DurationUnit.MILLISECONDS), data)
}
class EventTimerResults(totalWallTime: Double, data: Map<String, EventMetrics>) : Map<String, EventMetrics> by data {
val summary by lazy {
val builder = StringBuilder()
builder.append("[overall] totalWallTime: ${totalWallTime.roundedString(2)}, totalEventTime: ${data.values.map { it.totalEventTime}.sum().roundedString(2)} ")
for (entry in data) {
builder.append("[${entry.key}] totalEventTime: ${entry.value.totalEventTime.roundedString(2)}, count: ${entry.value.eventCount}, p50: ${entry.value.p(50)}, p90: ${entry.value.p(90)}, p99: ${entry.value.p(99)} ")
}
builder.toString()
}
}
fun interface EventStopper {
fun stopEvent()
}
data class EventMetrics(
/** The sum of all event times, in fractional milliseconds. If running operations in parallel, this will likely be larger than [totalWallTime]. */
val totalEventTime: Double,
/** Total number of events observed. */
val eventCount: Int,
private val sortedDurationNanos: List<Long>
) {
/**
* Returns the percentile of the duration data (e.g. p50, p90) as a formatted string containing fractional milliseconds rounded to the requested number of decimal places.
*/
fun p(percentile: Int, decimalPlaces: Int = 2): String {
return pNanos(percentile).nanoseconds.toDouble(DurationUnit.MILLISECONDS).roundedString(decimalPlaces)
}
private fun pNanos(percentile: Int): Long {
if (sortedDurationNanos.isEmpty()) {
return 0L
}
val index: Float = (percentile / 100f) * (sortedDurationNanos.size - 1)
val lowerIndex: Int = floor(index).toInt()
val upperIndex: Int = ceil(index).toInt()
if (lowerIndex == upperIndex) {
return sortedDurationNanos[lowerIndex]
}
val interpolationFactor: Float = index - lowerIndex
val lowerValue: Long = sortedDurationNanos[lowerIndex]
val upperValue: Long = sortedDurationNanos[upperIndex]
return floor(lowerValue + (upperValue - lowerValue) * interpolationFactor).toLong()
}
}
}

View File

@@ -0,0 +1,23 @@
package org.signal.core.util
import android.os.Build
import android.os.Parcel
import android.os.Parcelable
fun <T : Parcelable> Parcel.readParcelableCompat(clazz: Class<T>): T? {
return if (Build.VERSION.SDK_INT >= 33) {
this.readParcelable(clazz.classLoader, clazz)
} else {
@Suppress("DEPRECATION")
this.readParcelable(clazz.classLoader)
}
}
fun <T : java.io.Serializable> Parcel.readSerializableCompat(clazz: Class<T>): T? {
return if (Build.VERSION.SDK_INT >= 33) {
this.readSerializable(clazz.classLoader, clazz)
} else {
@Suppress("DEPRECATION", "UNCHECKED_CAST")
this.readSerializable() as T
}
}

View File

@@ -0,0 +1,47 @@
package org.signal.core.util
import android.app.PendingIntent
import android.os.Build
/**
* Wrapper class for lower level API compatibility with the new Pending Intents flags.
*
* This is meant to be a replacement to using PendingIntent flags independently, and should
* end up being the only place in our codebase that accesses these values.
*
* The "default" value is FLAG_MUTABLE
*/
object PendingIntentFlags {
@JvmStatic
fun updateCurrent(): Int {
return mutable() or PendingIntent.FLAG_UPDATE_CURRENT
}
@JvmStatic
fun cancelCurrent(): Int {
return mutable() or PendingIntent.FLAG_CANCEL_CURRENT
}
/**
* Flag indicating that this [PendingIntent] can be used only once. After [PendingIntent.send] is called on it,
* it will be automatically canceled for you and any future attempt to send through it will fail.
*/
@JvmStatic
fun oneShot(): Int {
return immutable() or PendingIntent.FLAG_ONE_SHOT
}
/**
* The backwards compatible "default" value for pending intent flags.
*/
@JvmStatic
fun mutable(): Int {
return if (Build.VERSION.SDK_INT >= 31) PendingIntent.FLAG_MUTABLE else 0
}
@JvmStatic
fun immutable(): Int {
return PendingIntent.FLAG_IMMUTABLE
}
}

View File

@@ -0,0 +1,33 @@
package org.signal.core.util;
import android.content.Context;
import android.content.res.Configuration;
import android.content.res.Resources;
import androidx.annotation.NonNull;
import androidx.annotation.StringRes;
import java.util.Locale;
/**
* Gives access to English strings.
*/
public final class ResourceUtil {
private ResourceUtil() {
}
public static Resources getEnglishResources(@NonNull Context context) {
return getResources(context, Locale.ENGLISH);
}
public static Resources getResources(@NonNull Context context, @NonNull Locale locale) {
Configuration configurationLocal = context.getResources().getConfiguration();
Configuration configurationEn = new Configuration(configurationLocal);
configurationEn.setLocale(locale);
return context.createConfigurationContext(configurationEn)
.getResources();
}
}

View File

@@ -0,0 +1,61 @@
package org.signal.core.util
/**
* A Result that allows for generic definitions of success/failure values.
*/
sealed class Result<out S, out F> {
data class Failure<out F>(val failure: F) : Result<Nothing, F>()
data class Success<out S>(val success: S) : Result<S, Nothing>()
companion object {
@JvmStatic
fun <S> success(value: S) = Success(value)
@JvmStatic
fun <F> failure(value: F) = Failure(value)
}
/**
* Maps an Result<S, F> to an Result<T, F>. Failure values will pass through, while
* right values will be operated on by the parameter.
*/
fun <T> map(onSuccess: (S) -> T): Result<T, F> {
return when (this) {
is Failure -> this
is Success -> success(onSuccess(success))
}
}
/**
* Allows the caller to operate on the Result such that the correct function is applied
* to the value it contains.
*/
fun <T> either(
onSuccess: (S) -> T,
onFailure: (F) -> T
): T {
return when (this) {
is Success -> onSuccess(success)
is Failure -> onFailure(failure)
}
}
}
/**
* Maps an Result<L, R> to an Result<L, T>. Failure values will pass through, while
* right values will be operated on by the parameter.
*
* Note this is an extension method in order to make the generics happy.
*/
fun <T, S, F> Result<S, F>.flatMap(onSuccess: (S) -> Result<T, F>): Result<T, F> {
return when (this) {
is Result.Success -> onSuccess(success)
is Result.Failure -> this
}
}
/**
* Try is a specialization of Result where the Failure is fixed to Throwable.
*/
typealias Try<S> = Result<S, Throwable>

View File

@@ -0,0 +1,658 @@
package org.signal.core.util
import android.content.ContentValues
import android.database.Cursor
import android.database.sqlite.SQLiteDatabase
import androidx.core.content.contentValuesOf
import androidx.sqlite.db.SupportSQLiteDatabase
import androidx.sqlite.db.SupportSQLiteQueryBuilder
import androidx.sqlite.db.SupportSQLiteStatement
import org.signal.core.util.SqlUtil.ForeignKeyViolation
import org.signal.core.util.logging.Log
import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds
private val TAG = "SQLiteDatabaseExtensions"
/**
* Begins a transaction on the `this` database, runs the provided [block] providing the `this` value as it's argument
* within the transaction, and then ends the transaction successfully.
*
* @return The value returned by [block] if any
*/
inline fun <T : SupportSQLiteDatabase, R> T.withinTransaction(block: (T) -> R): R {
beginTransaction()
try {
val toReturn = block(this)
if (inTransaction()) {
setTransactionSuccessful()
}
return toReturn
} finally {
if (inTransaction()) {
endTransaction()
}
}
}
fun SupportSQLiteDatabase.getTableRowCount(table: String): Int {
return this.query("SELECT COUNT(*) FROM $table").use {
if (it.moveToFirst()) {
it.getInt(0)
} else {
0
}
}
}
fun SupportSQLiteDatabase.getAllTables(): List<String> {
return SqlUtil.getAllTables(this)
}
/**
* Returns a list of objects that represent the table definitions in the database. Basically the table name and then the SQL that was used to create it.
*/
fun SupportSQLiteDatabase.getAllTableDefinitions(): List<CreateStatement> {
return this
.select("name", "sql")
.from("sqlite_schema")
.where("type = ? AND sql NOT NULL AND name != ?", "table", "sqlite_sequence")
.run()
.readToList { cursor ->
CreateStatement(
name = cursor.requireNonNullString("name"),
statement = cursor.requireNonNullString("sql").replace(" ", "")
)
}
.filterNot { it.name.startsWith("sqlite_stat") }
.sortedBy { it.name }
}
/**
* Returns a list of objects that represent the index definitions in the database. Basically the index name and then the SQL that was used to create it.
*/
fun SupportSQLiteDatabase.getAllIndexDefinitions(): List<CreateStatement> {
return this
.select("name", "sql")
.from("sqlite_schema")
.where("type = ? AND sql NOT NULL", "index")
.run()
.readToList { cursor ->
CreateStatement(
name = cursor.requireNonNullString("name"),
statement = cursor.requireNonNullString("sql")
)
}
.sortedBy { it.name }
}
/**
* Retrieves the names of all triggers, sorted alphabetically.
*/
fun SupportSQLiteDatabase.getAllTriggerDefinitions(): List<CreateStatement> {
return this
.select("name", "sql")
.from("sqlite_schema")
.where("type = ? AND sql NOT NULL", "trigger")
.run()
.readToList {
CreateStatement(
name = it.requireNonNullString("name"),
statement = it.requireNonNullString("sql")
)
}
.sortedBy { it.name }
}
fun SupportSQLiteDatabase.getForeignKeys(): List<ForeignKeyConstraint> {
return SqlUtil.getAllTables(this)
.map { table ->
this.query("PRAGMA foreign_key_list($table)").readToList { cursor ->
ForeignKeyConstraint(
table = table,
column = cursor.requireNonNullString("from"),
dependsOnTable = cursor.requireNonNullString("table"),
dependsOnColumn = cursor.requireNonNullString("to"),
onDelete = cursor.requireString("on_delete") ?: "NOTHING"
)
}
}
.flatten()
}
fun SupportSQLiteDatabase.areForeignKeyConstraintsEnabled(): Boolean {
return this.query("PRAGMA foreign_keys", arrayOf()).use { cursor ->
cursor.moveToFirst() && cursor.getInt(0) != 0
}
}
/**
* Provides a list of all foreign key violations present.
* If a [targetTable] is specified, results will be limited to that table specifically.
* Otherwise, the check will be performed across all tables.
*/
@JvmOverloads
fun SupportSQLiteDatabase.getForeignKeyViolations(targetTable: String? = null): List<ForeignKeyViolation> {
return SqlUtil.getForeignKeyViolations(this, targetTable)
}
/**
* For tables that have an autoincrementing primary key, this will reset the key to start back at 1.
* IMPORTANT: This is quite dangerous! Only do this if you're effectively resetting the entire database.
*/
fun SupportSQLiteDatabase.resetAutoIncrementValue(targetTable: String) {
SqlUtil.resetAutoIncrementValue(this, targetTable)
}
/**
* Does a full WAL checkpoint (TRUNCATE mode, where the log is for sure flushed and the log is zero'd out).
* Will try up to [maxAttempts] times. Can technically fail if the database is too active and the checkpoint
* can't complete in a reasonable amount of time.
*
* See: https://www.sqlite.org/pragma.html#pragma_wal_checkpoint
*/
fun SupportSQLiteDatabase.fullWalCheckpoint(maxAttempts: Int = 3): Boolean {
var attempts = 0
while (attempts < maxAttempts) {
if (this.walCheckpoint()) {
return true
}
attempts++
}
return false
}
private fun SupportSQLiteDatabase.walCheckpoint(): Boolean {
return this.query("PRAGMA wal_checkpoint(TRUNCATE)").use { cursor ->
cursor.moveToFirst() && cursor.getInt(0) == 0
}
}
fun SupportSQLiteDatabase.getIndexes(): List<Index> {
return this.query("SELECT name, tbl_name FROM sqlite_master WHERE type='index' ORDER BY name ASC").readToList { cursor ->
val indexName = cursor.requireNonNullString("name")
Index(
name = indexName,
table = cursor.requireNonNullString("tbl_name"),
columns = this.query("PRAGMA index_info($indexName)").readToList { it.requireNonNullString("name") }
)
}
}
fun SupportSQLiteDatabase.forceForeignKeyConstraintsEnabled(enabled: Boolean, timeout: Duration = 10.seconds) {
val startTime = System.currentTimeMillis()
while (true) {
try {
this.setForeignKeyConstraintsEnabled(enabled)
break
} catch (e: IllegalStateException) {
if (System.currentTimeMillis() - startTime > timeout.inWholeMilliseconds) {
throw IllegalStateException("Failed to force foreign keys to '$enabled' within the timeout of $timeout", e)
}
Log.w(TAG, "Failed to set foreign keys because we're in a transaction. Waiting 100ms then trying again.")
ThreadUtil.sleep(100)
}
}
}
/**
* Checks if a row exists that matches the query.
*/
fun SupportSQLiteDatabase.exists(table: String): ExistsBuilderPart1 {
return ExistsBuilderPart1(this, table)
}
/**
* Begins a SELECT statement with a helpful builder pattern.
*/
fun SupportSQLiteDatabase.select(vararg columns: String): SelectBuilderPart1 {
return SelectBuilderPart1(this, arrayOf(*columns))
}
/**
* Begins a COUNT statement with a helpful builder pattern.
*/
fun SupportSQLiteDatabase.count(): SelectBuilderPart1 {
return SelectBuilderPart1(this, SqlUtil.COUNT)
}
/**
* Begins an UPDATE statement with a helpful builder pattern.
* Requires a WHERE clause as a way of mitigating mistakes. If you'd like to update all items in the table, use [updateAll].
*/
fun SupportSQLiteDatabase.update(tableName: String): UpdateBuilderPart1 {
return UpdateBuilderPart1(this, tableName)
}
fun SupportSQLiteDatabase.updateAll(tableName: String): UpdateAllBuilderPart1 {
return UpdateAllBuilderPart1(this, tableName)
}
/**
* Begins a DELETE statement with a helpful builder pattern.
* Requires a WHERE clause as a way of mitigating mistakes. If you'd like to delete all items in the table, use [deleteAll].
*/
fun SupportSQLiteDatabase.delete(tableName: String): DeleteBuilderPart1 {
return DeleteBuilderPart1(this, tableName)
}
/**
* Deletes all data in the table.
*/
fun SupportSQLiteDatabase.deleteAll(tableName: String): Int {
return this.delete(tableName, null, arrayOfNulls<String>(0))
}
/**
* Begins an INSERT statement with a helpful builder pattern.
*/
fun SupportSQLiteDatabase.insertInto(tableName: String): InsertBuilderPart1 {
return InsertBuilderPart1(this, tableName)
}
/**
* Bind an arbitrary value to an index. It will handle calling the correct bind method based on the class type.
* @param index The index you want to bind to. Important: Indexes start at 1, not 0.
*/
fun SupportSQLiteStatement.bindValue(index: Int, value: Any?) {
when (value) {
null -> this.bindNull(index)
is DatabaseId -> this.bindString(index, value.serialize())
is Boolean -> this.bindLong(index, value.toInt().toLong())
is ByteArray -> this.bindBlob(index, value)
is Number -> {
if (value.toLong() == value || value.toInt() == value || value.toShort() == value || value.toByte() == value) {
this.bindLong(index, value.toLong())
} else {
this.bindDouble(index, value.toDouble())
}
}
else -> this.bindString(index, value.toString())
}
}
class SelectBuilderPart1(
private val db: SupportSQLiteDatabase,
private val columns: Array<String>
) {
fun from(tableName: String): SelectBuilderPart2 {
return SelectBuilderPart2(db, columns, tableName)
}
}
class SelectBuilderPart2(
private val db: SupportSQLiteDatabase,
private val columns: Array<String>,
private val tableName: String
) {
fun where(where: String, vararg whereArgs: Any): SelectBuilderPart3 {
return SelectBuilderPart3(db, columns, tableName, where, SqlUtil.buildArgs(*whereArgs))
}
fun where(where: String, whereArgs: Array<String>): SelectBuilderPart3 {
return SelectBuilderPart3(db, columns, tableName, where, whereArgs)
}
fun orderBy(orderBy: String): SelectBuilderPart4a {
return SelectBuilderPart4a(db, columns, tableName, "", arrayOf(), orderBy)
}
fun limit(limit: Int): SelectBuilderPart4b {
return SelectBuilderPart4b(db, columns, tableName, "", arrayOf(), limit.toString())
}
fun run(): Cursor {
return db.query(
SupportSQLiteQueryBuilder
.builder(tableName)
.columns(columns)
.create()
)
}
}
class SelectBuilderPart3(
private val db: SupportSQLiteDatabase,
private val columns: Array<String>,
private val tableName: String,
private val where: String,
private val whereArgs: Array<String>
) {
fun orderBy(orderBy: String): SelectBuilderPart4a {
return SelectBuilderPart4a(db, columns, tableName, where, whereArgs, orderBy)
}
fun limit(limit: Int): SelectBuilderPart4b {
return SelectBuilderPart4b(db, columns, tableName, where, whereArgs, limit.toString())
}
fun limit(limit: String): SelectBuilderPart4b {
return SelectBuilderPart4b(db, columns, tableName, where, whereArgs, limit)
}
fun limit(limit: Int, offset: Int): SelectBuilderPart4b {
return SelectBuilderPart4b(db, columns, tableName, where, whereArgs, "$offset,$limit")
}
fun groupBy(groupBy: String): SelectBuilderPart4c {
return SelectBuilderPart4c(db, columns, tableName, where, whereArgs, groupBy)
}
fun run(): Cursor {
return db.query(
SupportSQLiteQueryBuilder
.builder(tableName)
.columns(columns)
.selection(where, whereArgs)
.create()
)
}
}
class SelectBuilderPart4a(
private val db: SupportSQLiteDatabase,
private val columns: Array<String>,
private val tableName: String,
private val where: String,
private val whereArgs: Array<String>,
private val orderBy: String
) {
fun limit(limit: Int): SelectBuilderPart5 {
return SelectBuilderPart5(db, columns, tableName, where, whereArgs, orderBy, limit.toString())
}
fun limit(limit: String): SelectBuilderPart5 {
return SelectBuilderPart5(db, columns, tableName, where, whereArgs, orderBy, limit)
}
fun limit(limit: Int, offset: Int): SelectBuilderPart5 {
return SelectBuilderPart5(db, columns, tableName, where, whereArgs, orderBy, "$offset,$limit")
}
fun run(): Cursor {
return db.query(
SupportSQLiteQueryBuilder
.builder(tableName)
.columns(columns)
.selection(where, whereArgs)
.orderBy(orderBy)
.create()
)
}
}
class SelectBuilderPart4b(
private val db: SupportSQLiteDatabase,
private val columns: Array<String>,
private val tableName: String,
private val where: String,
private val whereArgs: Array<String>,
private val limit: String
) {
fun orderBy(orderBy: String): SelectBuilderPart5 {
return SelectBuilderPart5(db, columns, tableName, where, whereArgs, orderBy, limit)
}
fun run(): Cursor {
return db.query(
SupportSQLiteQueryBuilder
.builder(tableName)
.columns(columns)
.selection(where, whereArgs)
.limit(limit)
.create()
)
}
}
class SelectBuilderPart4c(
private val db: SupportSQLiteDatabase,
private val columns: Array<String>,
private val tableName: String,
private val where: String,
private val whereArgs: Array<String>,
private val groupBy: String
) {
fun run(): Cursor {
return db.query(
SupportSQLiteQueryBuilder
.builder(tableName)
.columns(columns)
.selection(where, whereArgs)
.groupBy(groupBy)
.create()
)
}
}
class SelectBuilderPart5(
private val db: SupportSQLiteDatabase,
private val columns: Array<String>,
private val tableName: String,
private val where: String,
private val whereArgs: Array<String>,
private val orderBy: String,
private val limit: String
) {
fun run(): Cursor {
return db.query(
SupportSQLiteQueryBuilder
.builder(tableName)
.columns(columns)
.selection(where, whereArgs)
.orderBy(orderBy)
.limit(limit)
.create()
)
}
}
class UpdateBuilderPart1(
private val db: SupportSQLiteDatabase,
private val tableName: String
) {
fun values(values: ContentValues): UpdateBuilderPart2 {
return UpdateBuilderPart2(db, tableName, values)
}
fun values(vararg values: Pair<String, Any?>): UpdateBuilderPart2 {
return UpdateBuilderPart2(db, tableName, contentValuesOf(*values))
}
}
class UpdateBuilderPart2(
private val db: SupportSQLiteDatabase,
private val tableName: String,
private val values: ContentValues
) {
fun where(where: String, vararg whereArgs: Any): UpdateBuilderPart3 {
require(where.isNotBlank())
return UpdateBuilderPart3(db, tableName, values, where, whereArgs.toArgs())
}
fun where(where: String, whereArgs: Array<String>): UpdateBuilderPart3 {
require(where.isNotBlank())
return UpdateBuilderPart3(db, tableName, values, where, whereArgs)
}
}
class UpdateBuilderPart3(
private val db: SupportSQLiteDatabase,
private val tableName: String,
private val values: ContentValues,
private val where: String,
private val whereArgs: Array<out Any?>
) {
@JvmOverloads
fun run(): Int {
val query = StringBuilder("UPDATE $tableName SET ")
val contentValuesKeys = values.keySet()
for ((index, column) in contentValuesKeys.withIndex()) {
query.append(column).append(" = ?")
if (index < contentValuesKeys.size - 1) {
query.append(", ")
}
}
query.append(" WHERE ").append(where)
val statement = db.compileStatement(query.toString())
var bindIndex = 1
for (key in contentValuesKeys) {
statement.bindValue(bindIndex, values.get(key))
bindIndex++
}
for (arg in whereArgs) {
statement.bindValue(bindIndex, arg)
bindIndex++
}
return statement.use { it.executeUpdateDelete() }
}
}
class UpdateAllBuilderPart1(
private val db: SupportSQLiteDatabase,
private val tableName: String
) {
fun values(values: ContentValues): UpdateAllBuilderPart2 {
return UpdateAllBuilderPart2(db, tableName, values)
}
fun values(vararg values: Pair<String, Any?>): UpdateAllBuilderPart2 {
return UpdateAllBuilderPart2(db, tableName, contentValuesOf(*values))
}
}
class UpdateAllBuilderPart2(
private val db: SupportSQLiteDatabase,
private val tableName: String,
private val values: ContentValues
) {
@JvmOverloads
fun run(conflictStrategy: Int = SQLiteDatabase.CONFLICT_NONE): Int {
return db.update(tableName, conflictStrategy, values, null, emptyArray<String>())
}
}
class DeleteBuilderPart1(
private val db: SupportSQLiteDatabase,
private val tableName: String
) {
fun where(where: String, vararg whereArgs: Any): DeleteBuilderPart2 {
require(where.isNotBlank())
return DeleteBuilderPart2(db, tableName, where, SqlUtil.buildArgs(*whereArgs))
}
fun where(where: String, whereArgs: Array<String>): DeleteBuilderPart2 {
require(where.isNotBlank())
return DeleteBuilderPart2(db, tableName, where, whereArgs)
}
}
class DeleteBuilderPart2(
private val db: SupportSQLiteDatabase,
private val tableName: String,
private val where: String,
private val whereArgs: Array<String>
) {
fun run(): Int {
return db.delete(tableName, where, whereArgs)
}
}
class ExistsBuilderPart1(
private val db: SupportSQLiteDatabase,
private val tableName: String
) {
fun where(where: String, vararg whereArgs: Any): ExistsBuilderPart2 {
return ExistsBuilderPart2(db, tableName, where, SqlUtil.buildArgs(*whereArgs))
}
fun where(where: String, whereArgs: Array<String>): ExistsBuilderPart2 {
return ExistsBuilderPart2(db, tableName, where, whereArgs)
}
fun run(): Boolean {
return db.query("SELECT EXISTS(SELECT 1 FROM $tableName)", arrayOf()).use { cursor ->
cursor.moveToFirst() && cursor.getInt(0) == 1
}
}
}
class ExistsBuilderPart2(
private val db: SupportSQLiteDatabase,
private val tableName: String,
private val where: String,
private val whereArgs: Array<String>
) {
fun run(): Boolean {
return db.query("SELECT EXISTS(SELECT 1 FROM $tableName WHERE $where)", SqlUtil.buildArgs(*whereArgs)).use { cursor ->
cursor.moveToFirst() && cursor.getInt(0) == 1
}
}
}
class InsertBuilderPart1(
private val db: SupportSQLiteDatabase,
private val tableName: String
) {
fun values(values: ContentValues): InsertBuilderPart2 {
return InsertBuilderPart2(db, tableName, values)
}
fun values(vararg values: Pair<String, Any?>): InsertBuilderPart2 {
return InsertBuilderPart2(db, tableName, contentValuesOf(*values))
}
}
class InsertBuilderPart2(
private val db: SupportSQLiteDatabase,
private val tableName: String,
private val values: ContentValues
) {
fun run(conflictStrategy: Int = SQLiteDatabase.CONFLICT_IGNORE): Long {
return db.insert(tableName, conflictStrategy, values)
}
}
/**
* Helper function to massage passed-in arguments into a better form to give to the database.
*/
private fun Array<out Any?>.toArgs(): Array<Any?> {
return this
.map {
when (it) {
is DatabaseId -> it.serialize()
else -> it
}
}
.toTypedArray()
}
data class ForeignKeyConstraint(
val table: String,
val column: String,
val dependsOnTable: String,
val dependsOnColumn: String,
val onDelete: String
)
data class Index(
val name: String,
val table: String,
val columns: List<String>
)
data class CreateStatement(
val name: String,
val statement: String
)

View File

@@ -0,0 +1,41 @@
package org.signal.core.util
import android.content.ContentValues
import android.database.Cursor
/**
* Generalized serializer for finer control
*/
interface BaseSerializer<Data, Input, Output> {
fun serialize(data: Data): Output
fun deserialize(input: Input): Data
}
/**
* Generic serialization interface for use with database and store operations.
*/
interface Serializer<T, R> : BaseSerializer<T, R, R>
/**
* Serializer specifically for working with SQLite
*/
interface DatabaseSerializer<Data> : BaseSerializer<Data, Cursor, ContentValues>
interface StringSerializer<T> : Serializer<T, String>
interface IntSerializer<T> : Serializer<T, Int>
interface LongSerializer<T> : Serializer<T, Long>
interface ByteSerializer<T> : Serializer<T, ByteArray>
object StringStringSerializer : StringSerializer<String?> {
override fun serialize(data: String?): String {
return data ?: ""
}
override fun deserialize(data: String): String {
return data
}
}

View File

@@ -0,0 +1,238 @@
// Copyright 2010 Square, Inc.
// Modified 2020 Signal
package org.signal.core.util;
import android.hardware.Sensor;
import android.hardware.SensorEvent;
import android.hardware.SensorEventListener;
import android.hardware.SensorManager;
/**
* Detects phone shaking. If more than 75% of the samples taken in the past 0.5s are
* accelerating, the device is a) shaking, or b) free falling 1.84m (h =
* 1/2*g*t^2*3/4).
*
* @author Bob Lee (bob@squareup.com)
* @author Eric Burke (eric@squareup.com)
*/
public class ShakeDetector implements SensorEventListener {
private static final int SHAKE_THRESHOLD = 13;
/** Listens for shakes. */
public interface Listener {
/** Called on the main thread when the device is shaken. */
void onShakeDetected();
}
private final SampleQueue queue = new SampleQueue();
private final Listener listener;
private SensorManager sensorManager;
private Sensor accelerometer;
public ShakeDetector(Listener listener) {
this.listener = listener;
}
/**
* Starts listening for shakes on devices with appropriate hardware.
*
* @return true if the device supports shake detection.
*/
public boolean start(SensorManager sensorManager) {
if (accelerometer != null) {
return true;
}
accelerometer = sensorManager.getDefaultSensor(Sensor.TYPE_ACCELEROMETER);
if (accelerometer != null) {
this.sensorManager = sensorManager;
sensorManager.registerListener(this, accelerometer, SensorManager.SENSOR_DELAY_NORMAL);
}
return accelerometer != null;
}
/**
* Stops listening. Safe to call when already stopped. Ignored on devices without appropriate
* hardware.
*/
public void stop() {
if (accelerometer != null) {
queue.clear();
sensorManager.unregisterListener(this, accelerometer);
sensorManager = null;
accelerometer = null;
}
}
@Override
public void onSensorChanged(SensorEvent event) {
boolean accelerating = isAccelerating(event);
long timestamp = event.timestamp;
queue.add(timestamp, accelerating);
if (queue.isShaking()) {
queue.clear();
listener.onShakeDetected();
}
}
/** Returns true if the device is currently accelerating. */
private boolean isAccelerating(SensorEvent event) {
float ax = event.values[0];
float ay = event.values[1];
float az = event.values[2];
// Instead of comparing magnitude to ACCELERATION_THRESHOLD,
// compare their squares. This is equivalent and doesn't need the
// actual magnitude, which would be computed using (expensive) Math.sqrt().
final double magnitudeSquared = ax * ax + ay * ay + az * az;
return magnitudeSquared > SHAKE_THRESHOLD * SHAKE_THRESHOLD;
}
/** Queue of samples. Keeps a running average. */
static class SampleQueue {
/** Window size in ns. Used to compute the average. */
private static final long MAX_WINDOW_SIZE = 500000000; // 0.5s
private static final long MIN_WINDOW_SIZE = MAX_WINDOW_SIZE >> 1; // 0.25s
/**
* Ensure the queue size never falls below this size, even if the device
* fails to deliver this many events during the time window. The LG Ally
* is one such device.
*/
private static final int MIN_QUEUE_SIZE = 4;
private final SamplePool pool = new SamplePool();
private Sample oldest;
private Sample newest;
private int sampleCount;
private int acceleratingCount;
/**
* Adds a sample.
*
* @param timestamp in nanoseconds of sample
* @param accelerating true if > {@link #SHAKE_THRESHOLD}.
*/
void add(long timestamp, boolean accelerating) {
purge(timestamp - MAX_WINDOW_SIZE);
Sample added = pool.acquire();
added.timestamp = timestamp;
added.accelerating = accelerating;
added.next = null;
if (newest != null) {
newest.next = added;
}
newest = added;
if (oldest == null) {
oldest = added;
}
sampleCount++;
if (accelerating) {
acceleratingCount++;
}
}
/** Removes all samples from this queue. */
void clear() {
while (oldest != null) {
Sample removed = oldest;
oldest = removed.next;
pool.release(removed);
}
newest = null;
sampleCount = 0;
acceleratingCount = 0;
}
/** Purges samples with timestamps older than cutoff. */
void purge(long cutoff) {
while (sampleCount >= MIN_QUEUE_SIZE && oldest != null && cutoff - oldest.timestamp > 0) {
Sample removed = oldest;
if (removed.accelerating) {
acceleratingCount--;
}
sampleCount--;
oldest = removed.next;
if (oldest == null) {
newest = null;
}
pool.release(removed);
}
}
/**
* Returns true if we have enough samples and more than 3/4 of those samples
* are accelerating.
*/
boolean isShaking() {
return newest != null &&
oldest != null &&
newest.timestamp - oldest.timestamp >= MIN_WINDOW_SIZE &&
acceleratingCount >= (sampleCount >> 1) + (sampleCount >> 2);
}
}
/** An accelerometer sample. */
static class Sample {
/** Time sample was taken. */
long timestamp;
/** If acceleration > {@link #SHAKE_THRESHOLD}. */
boolean accelerating;
/** Next sample in the queue or pool. */
Sample next;
}
/** Pools samples. Avoids garbage collection. */
static class SamplePool {
private Sample head;
/** Acquires a sample from the pool. */
Sample acquire() {
Sample acquired = head;
if (acquired == null) {
acquired = new Sample();
} else {
head = acquired.next;
}
return acquired;
}
/** Returns a sample to the pool. */
void release(Sample sample) {
sample.next = head;
head = sample;
}
}
@Override
public void onAccuracyChanged(Sensor sensor, int accuracy) {
}
}

View File

@@ -0,0 +1,518 @@
package org.signal.core.util
import android.content.ContentValues
import android.text.TextUtils
import androidx.annotation.VisibleForTesting
import androidx.sqlite.db.SupportSQLiteDatabase
import org.signal.core.util.logging.Log
import java.lang.Exception
import java.util.LinkedList
import java.util.Locale
import java.util.stream.Collectors
object SqlUtil {
private val TAG = Log.tag(SqlUtil::class.java)
/** The maximum number of arguments (i.e. question marks) allowed in a SQL statement. */
const val MAX_QUERY_ARGS = 999
@JvmField
val COUNT = arrayOf("COUNT(*)")
@JvmStatic
fun tableExists(db: SupportSQLiteDatabase, table: String): Boolean {
db.query("SELECT name FROM sqlite_master WHERE type=? AND name=?", arrayOf("table", table)).use { cursor ->
return cursor != null && cursor.moveToNext()
}
}
@JvmStatic
fun getAllTables(db: SupportSQLiteDatabase): List<String> {
val tables: MutableList<String> = LinkedList()
db.query("SELECT name FROM sqlite_master WHERE type=?", arrayOf("table")).use { cursor ->
while (cursor.moveToNext()) {
tables.add(cursor.getString(0))
}
}
return tables
}
/**
* Returns the total number of changes that have been made since the creation of this database connection.
*
* IMPORTANT: Due to how connection pooling is handled in the app, the only way to have this return useful numbers is to call it within a transaction.
*/
fun getTotalChanges(db: SupportSQLiteDatabase): Long {
return db.query("SELECT total_changes()", arrayOf()).readToSingleLong()
}
@JvmStatic
fun getAllTriggers(db: SupportSQLiteDatabase): List<String> {
val tables: MutableList<String> = LinkedList()
db.query("SELECT name FROM sqlite_master WHERE type=?", arrayOf("trigger")).use { cursor ->
while (cursor.moveToNext()) {
tables.add(cursor.getString(0))
}
}
return tables
}
@JvmStatic
fun getNextAutoIncrementId(db: SupportSQLiteDatabase, table: String): Long {
db.query("SELECT * FROM sqlite_sequence WHERE name = ?", arrayOf(table)).use { cursor ->
if (cursor.moveToFirst()) {
val current = cursor.requireLong("seq")
return current + 1
} else if (db.query("SELECT COUNT(*) FROM $table").readToSingleLong(defaultValue = 0) == 0L) {
Log.w(TAG, "No entries exist in $table. Returning 1.")
return 1
} else if (columnExists(db, table, "_id")) {
Log.w(TAG, "There are entries in $table, but we couldn't get the auto-incrementing id? Using the max _id in the table.")
val current = db.query("SELECT MAX(_id) FROM $table").readToSingleLong(defaultValue = 0)
return current + 1
} else {
Log.w(TAG, "No autoincrement _id, non-empty table, no _id column!")
throw IllegalArgumentException("Table must have an auto-incrementing primary key!")
}
}
}
/**
* Given a table, this will return a set of tables that it has a foreign key dependency on.
*/
@JvmStatic
fun getForeignKeyDependencies(db: SupportSQLiteDatabase, table: String): Set<String> {
return db.query("PRAGMA foreign_key_list($table)")
.readToSet { cursor ->
cursor.requireNonNullString("table")
}
}
/**
* Provides a list of all foreign key violations present.
* If a [targetTable] is specified, results will be limited to that table specifically.
* Otherwise, the check will be performed across all tables.
*/
@JvmStatic
@JvmOverloads
fun getForeignKeyViolations(db: SupportSQLiteDatabase, targetTable: String? = null): List<ForeignKeyViolation> {
val tableString = if (targetTable != null) "($targetTable)" else ""
return db.query("PRAGMA foreign_key_check$tableString").readToList { cursor ->
val table = cursor.requireNonNullString("table")
ForeignKeyViolation(
table = table,
violatingRowId = cursor.requireLongOrNull("rowid"),
dependsOnTable = cursor.requireNonNullString("parent"),
column = getForeignKeyViolationColumn(db, table, cursor.requireLong("fkid"))
)
}
}
/**
* For tables that have an autoincrementing primary key, this will reset the key to start back at 1.
* IMPORTANT: This is quite dangerous! Only do this if you're effectively resetting the entire database.
*/
@JvmStatic
fun resetAutoIncrementValue(db: SupportSQLiteDatabase, targetTable: String) {
db.execSQL("DELETE FROM sqlite_sequence WHERE name=?", arrayOf(targetTable))
}
@JvmStatic
fun isEmpty(db: SupportSQLiteDatabase, table: String): Boolean {
db.query("SELECT COUNT(*) FROM $table", arrayOf()).use { cursor ->
return if (cursor.moveToFirst()) {
cursor.getInt(0) == 0
} else {
true
}
}
}
@JvmStatic
fun columnExists(db: SupportSQLiteDatabase, table: String, column: String): Boolean {
db.query("PRAGMA table_info($table)", arrayOf()).use { cursor ->
val nameColumnIndex = cursor.getColumnIndexOrThrow("name")
while (cursor.moveToNext()) {
val name = cursor.getString(nameColumnIndex)
if (name == column) {
return true
}
}
}
return false
}
@JvmStatic
fun buildArgs(vararg objects: Any?): Array<String> {
return objects.map {
when (it) {
null -> throw NullPointerException("Cannot have null arg!")
is DatabaseId -> it.serialize()
else -> it.toString()
}
}.toTypedArray()
}
@JvmStatic
fun buildArgs(objects: Collection<Any?>): Array<String> {
return objects.map {
when (it) {
null -> throw NullPointerException("Cannot have null arg!")
is DatabaseId -> it.serialize()
else -> it.toString()
}
}.toTypedArray()
}
@JvmStatic
fun buildArgs(argument: Long): Array<String> {
return arrayOf(argument.toString())
}
/**
* Builds a case-insensitive GLOB pattern for fuzzy text queries. Works with all unicode
* characters.
*
* Ex:
* cat -> [cC][aA][tT]
*/
@JvmStatic
fun buildCaseInsensitiveGlobPattern(query: String): String {
if (TextUtils.isEmpty(query)) {
return "*"
}
val pattern = StringBuilder()
var i = 0
val len = query.codePointCount(0, query.length)
while (i < len) {
val point = StringUtil.codePointToString(query.codePointAt(i))
pattern.append("[")
pattern.append(point.lowercase(Locale.getDefault()))
pattern.append(point.uppercase(Locale.getDefault()))
pattern.append(getAccentuatedCharRegex(point.lowercase(Locale.getDefault())))
pattern.append("]")
i++
}
return "*$pattern*"
}
private fun getAccentuatedCharRegex(query: String): String {
return when (query) {
"a" -> "À-Åà-åĀ-ąǍǎǞ-ǡǺ-ǻȀ-ȃȦȧȺɐ-ɒḀḁẚẠ-ặ"
"b" -> "ßƀ-ƅɃɓḂ-ḇ"
"c" -> "çÇĆ-čƆ-ƈȻȼɔḈḉ"
"d" -> "ÐðĎ-đƉ-ƍȡɖɗḊ-ḓ"
"e" -> "È-Ëè-ëĒ-ěƎ-ƐǝȄ-ȇȨȩɆɇɘ-ɞḔ-ḝẸ-ệ"
"f" -> "ƑƒḞḟ"
"g" -> "Ĝ-ģƓǤ-ǧǴǵḠḡ"
"h" -> "Ĥ-ħƕǶȞȟḢ-ḫẖ"
"i" -> "Ì-Ïì-ïĨ-ıƖƗǏǐȈ-ȋɨɪḬ-ḯỈ-ị"
"j" -> "ĴĵǰȷɈɉɟ"
"k" -> "Ķ-ĸƘƙǨǩḰ-ḵ"
"l" -> "Ĺ-łƚȴȽɫ-ɭḶ-ḽ"
"m" -> "Ɯɯ-ɱḾ-ṃ"
"n" -> "ÑñŃ-ŋƝƞǸǹȠȵɲ-ɴṄ-ṋ"
"o" -> "Ò-ÖØò-öøŌ-őƟ-ơǑǒǪ-ǭǾǿȌ-ȏȪ-ȱṌ-ṓỌ-ợ"
"p" -> "ƤƥṔ-ṗ"
"q" -> ""
"r" -> "Ŕ-řƦȐ-ȓɌɍṘ-ṟ"
"s" -> "Ś-šƧƨȘșȿṠ-ṩ"
"t" -> "Ţ-ŧƫ-ƮȚțȾṪ-ṱẗ"
"u" -> "Ù-Üù-üŨ-ųƯ-ƱǓ-ǜȔ-ȗɄṲ-ṻỤ-ự"
"v" -> "ƲɅṼ-ṿ"
"w" -> "ŴŵẀ-ẉẘ"
"x" -> "Ẋ-ẍ"
"y" -> "ÝýÿŶ-ŸƔƳƴȲȳɎɏẎẏỲ-ỹỾỿẙ"
"z" -> "Ź-žƵƶɀẐ-ẕ"
"α" -> "\u0386\u0391\u03AC\u03B1\u1F00-\u1F0F\u1F70\u1F71\u1F80-\u1F8F\u1FB0-\u1FB4\u1FB6-\u1FBC"
"ε" -> "\u0388\u0395\u03AD\u03B5\u1F10-\u1F15\u1F18-\u1F1D\u1F72\u1F73\u1FC8\u1FC9"
"η" -> "\u0389\u0397\u03AE\u03B7\u1F20-\u1F2F\u1F74\u1F75\u1F90-\u1F9F\u1F20-\u1F2F\u1F74\u1F75\u1F90-\u1F9F\u1fc2\u1fc3\u1fc4\u1fc6\u1FC7\u1FCA\u1FCB\u1FCC"
"ι" -> "\u038A\u0390\u0399\u03AA\u03AF\u03B9\u03CA\u1F30-\u1F3F\u1F76\u1F77\u1FD0-\u1FD3\u1FD6-\u1FDB"
"ο" -> "\u038C\u039F\u03BF\u03CC\u1F40-\u1F45\u1F48-\u1F4D\u1F78\u1F79\u1FF8\u1FF9"
"σ" -> "\u03A3\u03C2\u03C3"
"ς" -> "\u03A3\u03C2\u03C3"
"υ" -> "\u038E\u03A5\u03AB\u03C5\u03CB\u03CD\u1F50-\u1F57\u1F59\u1F5B\u1F5D\u1F5F\u1F7A\u1F7B\u1FE0-\u1FE3\u1FE6-\u1FEB"
"ω" -> "\u038F\u03A9\u03C9\u03CE\u1F60-\u1F6F\u1F7C\u1F7D\u1FA0-\u1FAF\u1FF2-\u1FF4\u1FF6\u1FF7\u1FFA-\u1FFC"
else -> ""
}
}
/**
* Returns an updated query and args pairing that will only update rows that would *actually*
* change. In other words, if [SupportSQLiteDatabase.update]
* returns > 0, then you know something *actually* changed.
*/
@JvmStatic
fun buildTrueUpdateQuery(
selection: String,
args: Array<String>,
contentValues: ContentValues
): Query {
val qualifier = StringBuilder()
val valueSet = contentValues.valueSet()
val fullArgs: MutableList<String> = ArrayList(args.size + valueSet.size)
fullArgs.addAll(args)
var i = 0
for ((key, value) in valueSet) {
if (value != null) {
if (value is ByteArray) {
qualifier.append("hex(").append(key).append(") != ? OR ").append(key).append(" IS NULL")
fullArgs.add(Hex.toStringCondensed(value).uppercase(Locale.US))
} else {
qualifier.append(key).append(" != ? OR ").append(key).append(" IS NULL")
fullArgs.add(value.toString())
}
} else {
qualifier.append(key).append(" NOT NULL")
}
if (i != valueSet.size - 1) {
qualifier.append(" OR ")
}
i++
}
return Query("($selection) AND ($qualifier)", fullArgs.toTypedArray())
}
/**
* A convenient way of making queries in the form: WHERE [column] IN (?, ?, ..., ?)
* Handles breaking it
*/
@JvmOverloads
@JvmStatic
fun buildCollectionQuery(
column: String,
values: Collection<Any?>,
prefix: String = "",
maxSize: Int = MAX_QUERY_ARGS,
collectionOperator: CollectionOperator = CollectionOperator.IN
): List<Query> {
return if (values.isEmpty()) {
emptyList()
} else {
values
.chunked(maxSize)
.map { batch -> buildSingleCollectionQuery(column, batch, prefix, collectionOperator) }
}
}
/**
* A convenient way of making queries that are _equivalent_ to `WHERE [column] IN (?, ?, ..., ?)`
* Under the hood, it uses JSON1 functions which can both be surprisingly faster than normal (?, ?, ?) lists, as well as removes the [MAX_QUERY_ARGS] limit.
* This means chunking isn't necessary for any practical collection length.
*/
@JvmStatic
fun buildFastCollectionQuery(
column: String,
values: Collection<Any?>
): Query {
require(!values.isEmpty()) { "Must have values!" }
return Query("$column IN (SELECT e.value FROM json_each(?) e)", arrayOf(jsonEncode(buildArgs(values))))
}
/**
* A convenient way of making queries in the form: WHERE [column] IN (?, ?, ..., ?)
*
* Important: Should only be used if you know the number of values is < 1000. Otherwise you risk creating a SQL statement this is too large.
* Prefer [buildCollectionQuery] when possible.
*/
@JvmOverloads
@JvmStatic
fun buildSingleCollectionQuery(
column: String,
values: Collection<Any?>,
prefix: String = "",
collectionOperator: CollectionOperator = CollectionOperator.IN
): Query {
require(!values.isEmpty()) { "Must have values!" }
val query = StringBuilder()
val args = arrayOfNulls<Any>(values.size)
var i = 0
for (value in values) {
query.append("?")
args[i] = value
if (i != values.size - 1) {
query.append(", ")
}
i++
}
return Query("$prefix $column ${collectionOperator.sql} ($query)".trim(), buildArgs(*args))
}
@JvmStatic
fun buildCustomCollectionQuery(query: String, argList: List<Array<String>>): List<Query> {
return buildCustomCollectionQuery(query, argList, MAX_QUERY_ARGS)
}
@JvmStatic
@VisibleForTesting
fun buildCustomCollectionQuery(query: String, argList: List<Array<String>>, maxQueryArgs: Int): List<Query> {
val batchSize: Int = maxQueryArgs / argList[0].size
return ListUtil.chunk(argList, batchSize)
.stream()
.map { argBatch -> buildSingleCustomCollectionQuery(query, argBatch) }
.collect(Collectors.toList())
}
private fun buildSingleCustomCollectionQuery(query: String, argList: List<Array<String>>): Query {
val outputQuery = StringBuilder()
val outputArgs: MutableList<String> = mutableListOf()
var i = 0
val len = argList.size
while (i < len) {
outputQuery.append("(").append(query).append(")")
if (i < len - 1) {
outputQuery.append(" OR ")
}
val args = argList[i]
for (arg in args) {
outputArgs += arg
}
i++
}
return Query(outputQuery.toString(), outputArgs.toTypedArray())
}
@JvmStatic
fun buildQuery(where: String, vararg args: Any): Query {
return Query(where, buildArgs(*args))
}
@JvmStatic
fun appendArg(args: Array<String>, addition: String): Array<String> {
return args.toMutableList().apply {
add(addition)
}.toTypedArray()
}
@JvmStatic
fun appendArgs(args: Array<String>, vararg objects: Any?): Array<String> {
return args + buildArgs(objects)
}
@JvmStatic
fun buildBulkInsert(tableName: String, columns: Array<String>, contentValues: List<ContentValues>, onConflict: String? = null): List<Query> {
return buildBulkInsert(tableName, columns, contentValues, MAX_QUERY_ARGS)
}
@JvmStatic
@VisibleForTesting
fun buildBulkInsert(tableName: String, columns: Array<String>, contentValues: List<ContentValues>, maxQueryArgs: Int, onConflict: String? = null): List<Query> {
val batchSize = maxQueryArgs / columns.size
return contentValues
.chunked(batchSize)
.map { batch: List<ContentValues> -> buildSingleBulkInsert(tableName, columns, batch) }
.toList()
}
fun buildSingleBulkInsert(tableName: String, columns: Array<String>, contentValues: List<ContentValues>, onConflict: String? = null): Query {
val conflictString = onConflict?.let { " OR $onConflict" } ?: ""
val builder = StringBuilder()
builder.append("INSERT$conflictString INTO ").append(tableName).append(" (")
val columnString = columns.joinToString(separator = ", ")
builder.append(columnString)
builder.append(") VALUES ")
val placeholders = contentValues
.map { values ->
columns
.map { column ->
if (values[column] != null) {
if (values[column] is ByteArray) {
"X'${Hex.toStringCondensed(values[column] as ByteArray).uppercase()}'"
} else {
"?"
}
} else {
"null"
}
}
.joinToString(separator = ", ", prefix = "(", postfix = ")")
}
.joinToString(separator = ", ")
builder.append(placeholders)
val query = builder.toString()
val args: MutableList<String> = mutableListOf()
for (values in contentValues) {
for (column in columns) {
val value = values[column]
if (value != null && value !is ByteArray) {
args += value.toString()
}
}
}
return Query(query, args.toTypedArray())
}
/** Helper that gets the specific column for a foreign key violation */
private fun getForeignKeyViolationColumn(db: SupportSQLiteDatabase, table: String, id: Long): String? {
try {
db.query("PRAGMA foreign_key_list($table)").forEach { cursor ->
if (cursor.requireLong("id") == id) {
return cursor.requireString("from")
}
}
} catch (e: Exception) {
Log.w(TAG, "Failed to find violation details for id: $id")
}
return null
}
/** Simple encoding of a string array as a json array */
private fun jsonEncode(strings: Array<String>): String {
return strings.joinToString(prefix = "[", postfix = "]", separator = ",") { "\"$it\"" }
}
class Query(val where: String, val whereArgs: Array<String>) {
infix fun and(other: Query): Query {
return if (where.isNotEmpty() && other.where.isNotEmpty()) {
Query("($where) AND (${other.where})", whereArgs + other.whereArgs)
} else if (where.isNotEmpty()) {
this
} else {
other
}
}
}
data class ForeignKeyViolation(
/** The table that declared the REFERENCES clause. */
val table: String,
/** The rowId of the message in [table] that violates the constraint. Will not be present if the table has now rowId. */
val violatingRowId: Long?,
/** The table that [table] has a dependency on. */
val dependsOnTable: String,
/** The column from [table] that has the constraint. A separate query needs to be made to get this, so it's best-effor. */
val column: String?
)
enum class CollectionOperator(val sql: String) {
IN("IN"),
NOT_IN("NOT IN")
}
}

View File

@@ -0,0 +1,259 @@
package org.signal.core.util
import android.text.SpannableStringBuilder
import okio.ByteString
import okio.ByteString.Companion.toByteString
import okio.utf8Size
import java.io.ByteArrayOutputStream
import java.io.IOException
import java.nio.charset.StandardCharsets
object StringUtil {
private val WHITESPACE: Set<Char> = setOf(
'\u200E', // left-to-right mark
'\u200F', // right-to-left mark
'\u2007', // figure space
'\u200B', // zero-width space
'\u2800' // braille blank
)
/**
* Trims a name string to fit into the byte length requirement.
*
*
* This method treats a surrogate pair and a grapheme cluster a single character
* See examples in tests defined in StringUtilText_trimToFit.
*/
@JvmStatic
fun trimToFit(name: String?, maxByteLength: Int): String {
if (name.isNullOrEmpty()) {
return ""
}
if (name.utf8Size() <= maxByteLength) {
return name
}
try {
ByteArrayOutputStream().use { stream ->
for (graphemeCharacter in CharacterIterable(name)) {
val bytes = graphemeCharacter.toByteArray(StandardCharsets.UTF_8)
if (stream.size() + bytes.size <= maxByteLength) {
stream.write(bytes)
} else {
break
}
}
return stream.toString()
}
} catch (e: IOException) {
throw AssertionError(e)
}
}
/**
* @return A charsequence with no leading or trailing whitespace. Only creates a new charsequence
* if it has to.
*/
@JvmStatic
fun trim(charSequence: CharSequence): CharSequence {
if (charSequence.isEmpty()) {
return charSequence
}
var start = 0
var end = charSequence.length - 1
while (start < charSequence.length && Character.isWhitespace(charSequence[start])) {
start++
}
while (end >= 0 && end > start && Character.isWhitespace(charSequence[end])) {
end--
}
return if (start > 0 || end < charSequence.length - 1) {
charSequence.subSequence(start, end + 1)
} else {
charSequence
}
}
/**
* @return True if the string is empty, or if it contains nothing but whitespace characters.
* Accounts for various unicode whitespace characters.
*/
@JvmStatic
fun isVisuallyEmpty(value: String?): Boolean {
if (value.isNullOrEmpty()) {
return true
}
return indexOfFirstNonEmptyChar(value) == -1
}
/**
* @return String without any leading or trailing whitespace.
* Accounts for various unicode whitespace characters.
*/
@JvmStatic
fun trimToVisualBounds(value: String): String {
val start = indexOfFirstNonEmptyChar(value)
if (start == -1) {
return ""
}
val end = indexOfLastNonEmptyChar(value)
return value.substring(start, end + 1)
}
private fun indexOfFirstNonEmptyChar(value: String): Int {
val length = value.length
for (i in 0 until length) {
if (!isVisuallyEmpty(value[i])) {
return i
}
}
return -1
}
private fun indexOfLastNonEmptyChar(value: String): Int {
for (i in value.length - 1 downTo 0) {
if (!isVisuallyEmpty(value[i])) {
return i
}
}
return -1
}
/**
* @return True if the character is invisible or whitespace. Accounts for various unicode
* whitespace characters.
*/
fun isVisuallyEmpty(c: Char): Boolean {
return Character.isWhitespace(c) || WHITESPACE.contains(c)
}
/**
* @return A string representation of the provided unicode code point.
*/
fun codePointToString(codePoint: Int): String {
return String(Character.toChars(codePoint))
}
/**
* @return True if the text is null or has a length of 0, otherwise false.
*/
@JvmStatic
fun isEmpty(text: String?): Boolean {
return text.isNullOrEmpty()
}
/**
* Trims a [CharSequence] of starting and trailing whitespace. Behavior matches
* [String.trim] to preserve expectations around results.
*/
@JvmStatic
fun trimSequence(text: CharSequence): CharSequence {
var length = text.length
var startIndex = 0
while ((startIndex < length) && (text[startIndex] <= ' ')) {
startIndex++
}
while ((startIndex < length) && (text[length - 1] <= ' ')) {
length--
}
return if ((startIndex > 0 || length < text.length)) text.subSequence(startIndex, length) else text
}
/**
* If the {@param text} exceeds the {@param maxChars} it is trimmed in the middle so that the result is exactly {@param maxChars} long including an added
* ellipsis character.
*
*
* Otherwise the string is returned untouched.
*
*
* When {@param maxChars} is even, one more character is kept from the end of the string than the start.
*/
@JvmStatic
fun abbreviateInMiddle(text: CharSequence?, maxChars: Int): CharSequence? {
if (text == null || text.length <= maxChars) {
return text
}
val start = (maxChars - 1) / 2
val end = (maxChars - 1) - start
return text.subSequence(0, start).toString() + "" + text.subSequence(text.length - end, text.length)
}
/**
* @return The number of graphemes in the provided string.
*/
@JvmStatic
fun getGraphemeCount(text: CharSequence): Int {
val iterator = BreakIteratorCompat.getInstance()
iterator.setText(text)
return iterator.countBreaks()
}
@JvmStatic
fun replace(text: CharSequence, toReplace: Char, replacement: String?): CharSequence {
var updatedText: SpannableStringBuilder? = null
for (i in text.length - 1 downTo 0) {
if (text[i] == toReplace) {
if (updatedText == null) {
updatedText = SpannableStringBuilder.valueOf(text)
}
updatedText!!.replace(i, i + 1, replacement)
}
}
return updatedText ?: text
}
@JvmStatic
fun startsWith(text: CharSequence, substring: CharSequence): Boolean {
if (substring.length > text.length) {
return false
}
for (i in substring.indices) {
if (text[i] != substring[i]) {
return false
}
}
return true
}
@JvmStatic
fun endsWith(text: CharSequence, substring: CharSequence): Boolean {
if (substring.length > text.length) {
return false
}
var textIndex = text.length - 1
var substringIndex = substring.length - 1
while (substringIndex >= 0) {
if (text[textIndex] != substring[substringIndex]) {
return false
}
substringIndex--
textIndex--
}
return true
}
fun String?.toByteString(): ByteString? {
return this?.toByteArray()?.toByteString()
}
}

View File

@@ -0,0 +1,47 @@
package org.signal.core.util
import androidx.sqlite.db.SupportSQLiteProgram
import androidx.sqlite.db.SupportSQLiteQuery
fun SupportSQLiteQuery.toAndroidQuery(): SqlUtil.Query {
val program = CapturingSqliteProgram(this.argCount)
this.bindTo(program)
return SqlUtil.Query(this.sql, program.args())
}
private class CapturingSqliteProgram(count: Int) : SupportSQLiteProgram {
private val args: Array<String?> = arrayOfNulls(count)
fun args(): Array<String> {
return args.filterNotNull().toTypedArray()
}
override fun close() {
}
override fun bindNull(index: Int) {
throw UnsupportedOperationException()
}
override fun bindLong(index: Int, value: Long) {
args[index - 1] = value.toString()
}
override fun bindDouble(index: Int, value: Double) {
args[index - 1] = value.toString()
}
override fun bindString(index: Int, value: String) {
args[index - 1] = value
}
override fun bindBlob(index: Int, value: ByteArray) {
throw UnsupportedOperationException()
}
override fun clearBindings() {
for (i in args.indices) {
args[i] = null
}
}
}

View File

@@ -0,0 +1,115 @@
package org.signal.core.util;
import android.os.Handler;
import android.os.Looper;
import android.os.Process;
import androidx.annotation.NonNull;
import androidx.annotation.VisibleForTesting;
import java.util.concurrent.CountDownLatch;
/**
* Thread related utility functions.
*/
public final class ThreadUtil {
/**
* Default background thread priority.
*/
public static final int PRIORITY_BACKGROUND_THREAD = Process.THREAD_PRIORITY_BACKGROUND;
/**
* Important background thread priority. This is slightly lower priority than the UI thread. Use for critical work that should run as fast as
* possible, but shouldn't block the UI (e.g. message sends)
*/
public static final int PRIORITY_IMPORTANT_BACKGROUND_THREAD = Process.THREAD_PRIORITY_DEFAULT + Process.THREAD_PRIORITY_LESS_FAVORABLE;
/**
* As important as the UI thread. Use for absolutely critical UI blocking tasks/threads. For example fetching data for display in a recyclerview, or
* anything that will block UI.
*/
public static final int PRIORITY_UI_BLOCKING_THREAD = Process.THREAD_PRIORITY_DEFAULT;
private static volatile Handler handler;
@VisibleForTesting
public static volatile boolean enforceAssertions = true;
private ThreadUtil() {}
private static Handler getHandler() {
if (handler == null) {
synchronized (ThreadUtil.class) {
if (handler == null) {
handler = new Handler(Looper.getMainLooper());
}
}
}
return handler;
}
public static boolean isMainThread() {
return Looper.myLooper() == Looper.getMainLooper();
}
public static void assertMainThread() {
if (!isMainThread() && enforceAssertions) {
throw new AssertionError("Must run on main thread.");
}
}
public static void assertNotMainThread() {
if (isMainThread() && enforceAssertions) {
throw new AssertionError("Cannot run on main thread.");
}
}
public static void postToMain(final @NonNull Runnable runnable) {
getHandler().post(runnable);
}
public static void runOnMain(final @NonNull Runnable runnable) {
if (isMainThread()) runnable.run();
else getHandler().post(runnable);
}
public static void runOnMainDelayed(final @NonNull Runnable runnable, long delayMillis) {
getHandler().postDelayed(runnable, delayMillis);
}
public static void cancelRunnableOnMain(@NonNull Runnable runnable) {
getHandler().removeCallbacks(runnable);
}
public static void runOnMainSync(final @NonNull Runnable runnable) {
if (isMainThread()) {
runnable.run();
} else {
final CountDownLatch sync = new CountDownLatch(1);
runOnMain(() -> {
try {
runnable.run();
} finally {
sync.countDown();
}
});
try {
sync.await();
} catch (InterruptedException ie) {
throw new AssertionError(ie);
}
}
}
public static void sleep(long millis) {
try {
Thread.sleep(millis);
} catch (InterruptedException e) {
throw new AssertionError(e);
}
}
public static void interruptableSleep(long millis) {
try {
Thread.sleep(millis);
} catch (InterruptedException ignored) { }
}
}

View File

@@ -0,0 +1,23 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import android.content.res.ColorStateList
import android.graphics.PorterDuff
import android.graphics.PorterDuffColorFilter
import androidx.annotation.ColorInt
import androidx.appcompat.widget.Toolbar
import androidx.core.view.MenuItemCompat
import androidx.core.view.forEach
fun Toolbar.setActionItemTint(@ColorInt tint: Int) {
menu.forEach {
MenuItemCompat.setIconTintList(it, ColorStateList.valueOf(tint))
}
navigationIcon?.colorFilter = PorterDuffColorFilter(tint, PorterDuff.Mode.SRC_ATOP)
overflowIcon?.colorFilter = PorterDuffColorFilter(tint, PorterDuff.Mode.SRC_ATOP)
}

View File

@@ -0,0 +1,79 @@
package org.signal.core.util;
import android.content.Context;
import android.content.res.Configuration;
import android.content.res.Resources;
import android.os.Build;
import androidx.annotation.NonNull;
import androidx.annotation.StringRes;
import java.util.Locale;
/**
* Allows you to detect if a string resource is readable by the user according to their language settings.
*/
public final class TranslationDetection {
private final Resources resourcesLocal;
private final Resources resourcesEn;
private final Configuration configurationLocal;
/**
* @param context Do not pass Application context, as this may not represent the users selected in-app locale.
*/
public TranslationDetection(@NonNull Context context) {
this.resourcesLocal = context.getResources();
this.configurationLocal = resourcesLocal.getConfiguration();
this.resourcesEn = ResourceUtil.getEnglishResources(context);
}
/**
* @param context Can be Application context.
* @param usersLocale Locale of user.
*/
public TranslationDetection(@NonNull Context context, @NonNull Locale usersLocale) {
this.resourcesLocal = ResourceUtil.getResources(context.getApplicationContext(), usersLocale);
this.configurationLocal = resourcesLocal.getConfiguration();
this.resourcesEn = ResourceUtil.getEnglishResources(context);
}
/**
* Returns true if any of these are true:
* - The current locale is English.
* - In a multi-locale capable device, the device supports any English locale in any position.
* - The text for the current locale does not Equal the English.
*/
public boolean textExistsInUsersLanguage(@StringRes int resId) {
if (configSupportsEnglish()) {
return true;
}
String stringEn = resourcesEn.getString(resId);
String stringLocal = resourcesLocal.getString(resId);
return !stringEn.equals(stringLocal);
}
public boolean textExistsInUsersLanguage(@StringRes int... resIds) {
for (int resId : resIds) {
if (!textExistsInUsersLanguage(resId)) {
return false;
}
}
return true;
}
protected boolean configSupportsEnglish() {
if (configurationLocal.locale.getLanguage().equals("en")) {
return true;
}
if (Build.VERSION.SDK_INT >= 24) {
Locale firstMatch = configurationLocal.getLocales().getFirstMatch(new String[]{"en"});
return firstMatch != null && firstMatch.getLanguage().equals("en");
}
return false;
}
}

View File

@@ -0,0 +1,14 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.androidx
import androidx.documentfile.provider.DocumentFile
/**
* Information about a file within the storage. Useful because default [DocumentFile] implementations
* re-query info on each access.
*/
data class DocumentFileInfo(val documentFile: DocumentFile, val name: String, val size: Long)

View File

@@ -0,0 +1,221 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.androidx
import android.content.Context
import android.provider.DocumentsContract
import androidx.documentfile.provider.DocumentFile
import androidx.documentfile.provider.isTreeDocumentFile
import org.signal.core.util.ThreadUtil
import org.signal.core.util.logging.Log
import org.signal.core.util.readToList
import org.signal.core.util.requireLong
import org.signal.core.util.requireNonNullString
import org.signal.core.util.requireString
import java.io.InputStream
import java.io.OutputStream
import kotlin.time.Duration.Companion.seconds
/**
* Collection of helper and optimizized operations for working with [DocumentFile]s.
*/
object DocumentFileUtil {
private val TAG = Log.tag(DocumentFileUtil::class)
private val FILE_PROJECTION = arrayOf(DocumentsContract.Document.COLUMN_DOCUMENT_ID, DocumentsContract.Document.COLUMN_DISPLAY_NAME, DocumentsContract.Document.COLUMN_SIZE)
private const val FILE_SELECTION = "${DocumentsContract.Document.COLUMN_DISPLAY_NAME} = ?"
private const val LIST_FILES_SELECTION = "${DocumentsContract.Document.COLUMN_MIME_TYPE} != ?"
private val LIST_FILES_SELECTION_ARGS = arrayOf(DocumentsContract.Document.MIME_TYPE_DIR)
private const val MAX_STORAGE_ATTEMPTS: Int = 5
private val WAIT_FOR_SCOPED_STORAGE: LongArray = longArrayOf(0, 2.seconds.inWholeMilliseconds, 10.seconds.inWholeMilliseconds, 20.seconds.inWholeMilliseconds, 30.seconds.inWholeMilliseconds)
/** Returns true if the directory represented by the [DocumentFile] has a child with [name]. */
fun DocumentFile.hasFile(name: String): Boolean {
return findFile(name) != null
}
/** Returns the [DocumentFile] for a newly created binary file or null if unable or it already exists */
fun DocumentFile.newFile(name: String): DocumentFile? {
return if (hasFile(name)) {
Log.w(TAG, "Attempt to create new file ($name) but it already exists")
null
} else {
createFile("application/octet-stream", name)
}
}
/** Returns a [DocumentFile] for directory by [name], creating it if it doesn't already exist */
fun DocumentFile.mkdirp(name: String): DocumentFile? {
return findFile(name) ?: createDirectory(name)
}
/** Open an [OutputStream] to the file represented by the [DocumentFile] */
fun DocumentFile.outputStream(context: Context): OutputStream? {
return context.contentResolver.openOutputStream(uri)
}
/** Open an [InputStream] to the file represented by the [DocumentFile] */
@JvmStatic
fun DocumentFile.inputStream(context: Context): InputStream? {
return context.contentResolver.openInputStream(uri)
}
/**
* Will attempt to find the named [file] in the [root] directory and delete it if found.
*
* @return true if found and deleted, false if the file couldn't be deleted, and null if not found
*/
fun DocumentFile.delete(context: Context, file: String): Boolean? {
return findFile(context, file)?.documentFile?.delete()
}
/**
* Will attempt to find the name [fileName] in the [root] directory and return useful information if found using
* a single [Context.getContentResolver] query.
*
* Recommend using this over [DocumentFile.findFile] to prevent excess queries for all files and names.
*
* If direct queries fail to find the file, will fallback to using [DocumentFile.findFile].
*/
fun DocumentFile.findFile(context: Context, fileName: String): DocumentFileInfo? {
val child: List<DocumentFileInfo> = if (isTreeDocumentFile()) {
val childrenUri = DocumentsContract.buildChildDocumentsUriUsingTree(uri, DocumentsContract.getDocumentId(uri))
try {
context
.contentResolver
.query(childrenUri, FILE_PROJECTION, FILE_SELECTION, arrayOf(fileName), null)
?.readToList(predicate = { it.name == fileName }) { cursor ->
val uri = DocumentsContract.buildDocumentUriUsingTree(uri, cursor.requireString(DocumentsContract.Document.COLUMN_DOCUMENT_ID))
val displayName = cursor.requireNonNullString(DocumentsContract.Document.COLUMN_DISPLAY_NAME)
val length = cursor.requireLong(DocumentsContract.Document.COLUMN_SIZE)
DocumentFileInfo(DocumentFile.fromSingleUri(context, uri)!!, displayName, length)
} ?: emptyList()
} catch (e: Exception) {
Log.d(TAG, "Unable to find file directly on ${javaClass.simpleName}, falling back to OS", e)
emptyList()
}
} else {
emptyList()
}
return if (child.size == 1) {
child[0]
} else {
Log.w(TAG, "Did not find single file, found (${child.size}), falling back to OS")
this.findFile(fileName)?.let { DocumentFileInfo(it, it.name!!, it.length()) }
}
}
/**
* List file names and sizes in the [DocumentFile] by directly querying the content resolver ourselves. The system
* implementation makes a separate query for each name and length method call and gets expensive over 1000's of files.
*
* Will fallback to the provided document file's implementation of [DocumentFile.listFiles] if unable to do it directly.
*/
fun DocumentFile.listFiles(context: Context): List<DocumentFileInfo> {
if (isTreeDocumentFile()) {
val childrenUri = DocumentsContract.buildChildDocumentsUriUsingTree(uri, DocumentsContract.getDocumentId(uri))
try {
val results = context
.contentResolver
.query(childrenUri, FILE_PROJECTION, LIST_FILES_SELECTION, LIST_FILES_SELECTION_ARGS, null)
?.use { cursor ->
val results = ArrayList<DocumentFileInfo>(cursor.count)
while (cursor.moveToNext()) {
val uri = DocumentsContract.buildDocumentUriUsingTree(uri, cursor.requireString(DocumentsContract.Document.COLUMN_DOCUMENT_ID))
val displayName = cursor.requireString(DocumentsContract.Document.COLUMN_DISPLAY_NAME)
val length = cursor.requireLong(DocumentsContract.Document.COLUMN_SIZE)
if (displayName != null) {
results.add(DocumentFileInfo(DocumentFile.fromSingleUri(context, uri)!!, displayName, length))
}
}
results
}
if (results != null) {
return results
} else {
Log.w(TAG, "Content provider returned null for query on ${javaClass.simpleName}, falling back to OS")
}
} catch (e: Exception) {
Log.d(TAG, "Unable to query files directly on ${javaClass.simpleName}, falling back to OS", e)
}
}
return listFiles()
.asSequence()
.filter { it.isFile }
.mapNotNull { file -> file.name?.let { DocumentFileInfo(file, it, file.length()) } }
.toList()
}
/**
* System implementation swallows the exception and we are having problems with the rename. This inlines the
* same call and logs the exception. Note this implementation does not update the passed in document file like
* the system implementation. Do not use the provided document file after calling this method.
*
* @return true if rename successful
*/
@JvmStatic
fun DocumentFile.renameTo(context: Context, displayName: String): Boolean {
if (isTreeDocumentFile()) {
Log.d(TAG, "Renaming document directly")
try {
val result = DocumentsContract.renameDocument(context.contentResolver, uri, displayName)
return result != null
} catch (e: Exception) {
Log.w(TAG, "Unable to rename document file, falling back to OS", e)
return renameTo(displayName)
}
} else {
return renameTo(displayName)
}
}
/**
* Historically, we've seen issues with [DocumentFile] operations not working on the first try. This
* retry loop will retry those operations with a varying backoff in attempt to make them work.
*/
@JvmStatic
fun <T> retryDocumentFileOperation(operation: DocumentFileOperation<T>): OperationResult {
var attempts = 0
var operationResult = operation.operation(attempts, MAX_STORAGE_ATTEMPTS)
while (attempts < MAX_STORAGE_ATTEMPTS && !operationResult.isSuccess()) {
ThreadUtil.sleep(WAIT_FOR_SCOPED_STORAGE[attempts])
attempts++
operationResult = operation.operation(attempts, MAX_STORAGE_ATTEMPTS)
}
return operationResult
}
/** Operation to perform in a retry loop via [retryDocumentFileOperation] that could fail based on timing */
fun interface DocumentFileOperation<T> {
fun operation(attempt: Int, maxAttempts: Int): OperationResult
}
/** Result of a single operation in a retry loop via [retryDocumentFileOperation] */
sealed interface OperationResult {
fun isSuccess(): Boolean {
return this is Success
}
/** The operation completed successful */
data class Success(val value: Boolean) : OperationResult
/** Retry the operation */
data object Retry : OperationResult
}
}

View File

@@ -0,0 +1,43 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.billing
import android.app.Activity
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.emptyFlow
/**
* Variant interface for the BillingApi.
*/
interface BillingApi {
/**
* Listenable stream of billing purchase results. It's up to the user
* to call queryPurchases after subscription.
*/
fun getBillingPurchaseResults(): Flow<BillingPurchaseResult> = emptyFlow()
suspend fun getApiAvailability(): BillingResponseCode = BillingResponseCode.FEATURE_NOT_SUPPORTED
/**
* Queries the Billing API for product pricing. This value should be cached by
* the implementor for 24 hours.
*/
suspend fun queryProduct(): BillingProduct? = null
/**
* Queries the user's current purchases. This enqueues a check and will
* propagate it to the normal callbacks in the api.
*/
suspend fun queryPurchases(): BillingPurchaseResult = BillingPurchaseResult.None
suspend fun launchBillingFlow(activity: Activity) = Unit
/**
* Empty implementation, to be used when play services are available but
* GooglePlayBillingApi is not available.
*/
object Empty : BillingApi
}

View File

@@ -0,0 +1,28 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.billing
import android.content.Context
/**
* Provides a dependency model by which the billing api can request different resources.
*/
interface BillingDependencies {
/**
* Application context
*/
val context: Context
/**
* Get the product id from the donations configuration object.
*/
suspend fun getProductId(): String
/**
* Get the base plan id from the donations configuration object.
*/
suspend fun getBasePlanId(): String
}

View File

@@ -0,0 +1,10 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.billing
class BillingError(
val billingResponseCode: Int
) : Exception("$billingResponseCode")

View File

@@ -0,0 +1,15 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.billing
import org.signal.core.util.money.FiatMoney
/**
* Represents a purchasable product from the Google Play Billing API
*/
data class BillingProduct(
val price: FiatMoney
)

View File

@@ -0,0 +1,41 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.billing
/**
* Sealed class hierarchy representing the different success
* and error states of google play billing purchases.
*/
sealed interface BillingPurchaseResult {
data class Success(
val purchaseState: BillingPurchaseState,
val purchaseToken: String,
val isAcknowledged: Boolean,
val purchaseTime: Long,
val isAutoRenewing: Boolean
) : BillingPurchaseResult {
override fun toString(): String {
return """
BillingPurchaseResult {
purchaseState: $purchaseState
purchaseToken: <redacted>
purchaseTime: $purchaseTime
isAcknowledged: $isAcknowledged
isAutoRenewing: $isAutoRenewing
}
""".trimIndent()
}
}
data object UserCancelled : BillingPurchaseResult
data object None : BillingPurchaseResult
data object TryAgainLater : BillingPurchaseResult
data object AlreadySubscribed : BillingPurchaseResult
data object FeatureNotSupported : BillingPurchaseResult
data object GenericError : BillingPurchaseResult
data object NetworkError : BillingPurchaseResult
data object BillingUnavailable : BillingPurchaseResult
}

View File

@@ -0,0 +1,15 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.billing
/**
* BillingPurchaseState which aligns with the Google Play Billing purchased state.
*/
enum class BillingPurchaseState {
UNSPECIFIED,
PURCHASED,
PENDING
}

View File

@@ -0,0 +1,42 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.billing
import org.signal.core.util.logging.Log
enum class BillingResponseCode(val code: Int) {
UNKNOWN(code = Int.MIN_VALUE),
SERVICE_TIMEOUT(code = -3),
FEATURE_NOT_SUPPORTED(code = -2),
SERVICE_DISCONNECTED(code = -1),
OK(code = 0),
USER_CANCELED(code = 1),
SERVICE_UNAVAILABLE(code = 2),
BILLING_UNAVAILABLE(code = 3),
ITEM_UNAVAILABLE(code = 4),
DEVELOPER_ERROR(code = 5),
ERROR(code = 6),
ITEM_ALREADY_OWNED(code = 7),
ITEM_NOT_OWNED(code = 8),
NETWORK_ERROR(code = 12);
val isSuccess: Boolean get() = this == OK
companion object {
private val TAG = Log.tag(BillingResponseCode::class)
fun fromBillingLibraryResponseCode(responseCode: Int): BillingResponseCode {
val code = BillingResponseCode.entries.firstOrNull { responseCode == it.code } ?: UNKNOWN
if (code == UNKNOWN) {
Log.w(TAG, "Unknown response code: $code")
}
return code
}
}
}

View File

@@ -0,0 +1,132 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.concurrent
import android.os.Debug
import android.os.Looper
import androidx.annotation.MainThread
import org.signal.core.util.ThreadUtil
import org.signal.core.util.logging.Log
import java.lang.IllegalStateException
import java.lang.RuntimeException
import java.text.SimpleDateFormat
import java.util.Date
import java.util.Locale
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds
/**
* Attempts to detect ANR's by posting runnables to the main thread and detecting if they've been run within the [anrThreshold].
* If an ANR is detected, it is logged, and the [anrSaver] is called with the series of thread dumps that were taken of the main thread.
*
* The detection of an ANR will cause an internal user to crash.
*/
object AnrDetector {
private val TAG = Log.tag(AnrDetector::class.java)
private var thread: AnrDetectorThread? = null
private val dateFormat = SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS zzz", Locale.US)
@JvmStatic
@MainThread
fun start(anrThreshold: Long = 5.seconds.inWholeMilliseconds, isInternal: () -> Boolean, anrSaver: (String) -> Unit) {
thread?.end()
thread = null
thread = AnrDetectorThread(anrThreshold.milliseconds, isInternal, anrSaver)
thread!!.start()
}
@JvmStatic
@MainThread
fun stop() {
thread?.end()
thread = null
}
private class AnrDetectorThread(
private val anrThreshold: Duration,
private val isInternal: () -> Boolean,
private val anrSaver: (String) -> Unit
) : Thread("signal-anr") {
@Volatile
private var uiRan = false
private val uiRunnable = Runnable {
uiRan = true
}
@Volatile
private var stopped = false
override fun run() {
while (!stopped) {
uiRan = false
ThreadUtil.postToMain(uiRunnable)
val intervalCount = 5
val intervalDuration = anrThreshold.inWholeMilliseconds / intervalCount
if (intervalDuration == 0L) {
throw IllegalStateException("ANR threshold is too small!")
}
val dumps = mutableListOf<String>()
for (i in 1..intervalCount) {
if (stopped) {
Log.i(TAG, "Thread shutting down during intervals.")
return
}
ThreadUtil.sleep(intervalDuration)
if (!uiRan) {
dumps += getMainThreadDump()
} else {
dumps.clear()
}
}
if (!uiRan && !Debug.isDebuggerConnected() && !Debug.waitingForDebugger()) {
Log.w(TAG, "Failed to post to main in ${anrThreshold.inWholeMilliseconds} ms! Likely ANR!")
val dumpString = dumps.joinToString(separator = "\n\n")
Log.w(TAG, "Main thread dumps:\n$dumpString")
ThreadUtil.cancelRunnableOnMain(uiRunnable)
anrSaver(dumpString)
if (isInternal()) {
Log.e(TAG, "Internal user -- crashing!")
throw SignalAnrException()
}
}
dumps.clear()
}
Log.i(TAG, "Thread shutting down.")
}
fun end() {
stopped = true
}
private fun getMainThreadDump(): String {
val dump: Map<Thread, Array<StackTraceElement>> = Thread.getAllStackTraces()
val mainThread = Looper.getMainLooper().thread
val date = dateFormat.format(Date())
val dumpString = dump[mainThread]?.joinToString(separator = "\n") ?: "Not available."
return "--- $date:\n$dumpString"
}
}
private class SignalAnrException : RuntimeException()
}

View File

@@ -0,0 +1,146 @@
package org.signal.core.util.concurrent
import android.os.Handler
import org.signal.core.util.logging.Log
import java.util.concurrent.ExecutorService
import java.util.concurrent.ThreadPoolExecutor
/**
* A class that polls active threads at a set interval and logs when multiple threads are BLOCKED.
*/
class DeadlockDetector(private val handler: Handler, private val pollingInterval: Long) {
private var running = false
private val previouslyBlocked: MutableSet<Long> = mutableSetOf()
private val waitingStates: Set<Thread.State> = setOf(Thread.State.WAITING, Thread.State.TIMED_WAITING)
@Volatile
var lastThreadDump: Map<Thread, Array<StackTraceElement>>? = null
@Volatile
var lastThreadDumpTime: Long = -1
fun start() {
Log.d(TAG, "Beginning deadlock monitoring.")
running = true
handler.postDelayed(this::poll, pollingInterval)
}
fun stop() {
Log.d(TAG, "Ending deadlock monitoring.")
running = false
handler.removeCallbacksAndMessages(null)
}
private fun poll() {
val time: Long = System.currentTimeMillis()
val threads: Map<Thread, Array<StackTraceElement>> = Thread.getAllStackTraces()
val blocked: Map<Thread, Array<StackTraceElement>> = threads
.filter { entry ->
val thread: Thread = entry.key
val stack: Array<StackTraceElement> = entry.value
thread.state == Thread.State.BLOCKED || (thread.state.isWaiting() && stack.hasPotentialLock())
}
.filter { entry -> !BLOCK_BLOCKLIST.contains(entry.key.name) }
val blockedIds: Set<Long> = blocked.keys.map(Thread::getId).toSet()
val stillBlocked: Set<Long> = blockedIds.intersect(previouslyBlocked)
if (blocked.size > 1) {
Log.w(TAG, buildLogString("Found multiple blocked threads! Possible deadlock.", blocked))
lastThreadDump = threads
lastThreadDumpTime = time
} else if (stillBlocked.isNotEmpty()) {
val stillBlockedMap: Map<Thread, Array<StackTraceElement>> = stillBlocked
.map { blockedId ->
val key: Thread = blocked.keys.first { it.id == blockedId }
val value: Array<StackTraceElement> = blocked[key]!!
Pair(key, value)
}
.toMap()
Log.w(TAG, buildLogString("Found a long block! Blocked for at least $pollingInterval ms.", stillBlockedMap))
lastThreadDump = threads
lastThreadDumpTime = time
}
val fullExecutors: List<ExecutorInfo> = CHECK_FULLNESS_EXECUTORS.filter { isExecutorFull(it.executor) }
if (fullExecutors.isNotEmpty()) {
fullExecutors.forEach { executorInfo ->
val fullMap: Map<Thread, Array<StackTraceElement>> = threads
.filter { it.key.name.startsWith(executorInfo.namePrefix) }
.toMap()
val executor: ThreadPoolExecutor = executorInfo.executor as ThreadPoolExecutor
Log.w(TAG, buildLogString("Found a full executor! ${executor.activeCount}/${executor.maximumPoolSize} threads active with ${executor.queue.size} tasks queued.", fullMap))
}
lastThreadDump = threads
lastThreadDumpTime = time
}
previouslyBlocked.clear()
previouslyBlocked.addAll(blockedIds)
if (running) {
handler.postDelayed(this::poll, pollingInterval)
}
}
private data class ExecutorInfo(
val executor: ExecutorService,
val namePrefix: String
)
private fun Thread.State.isWaiting(): Boolean {
return waitingStates.contains(this)
}
private fun Array<StackTraceElement>.hasPotentialLock(): Boolean {
return any {
it.methodName.startsWith("lock") || (it.methodName.startsWith("waitForConnection") && !it.className.contains("IncomingMessageObserver"))
}
}
companion object {
private val TAG = Log.tag(DeadlockDetector::class.java)
private val CHECK_FULLNESS_EXECUTORS: Set<ExecutorInfo> = setOf(
ExecutorInfo(SignalExecutors.BOUNDED, "signal-bounded-"),
ExecutorInfo(SignalExecutors.BOUNDED_IO, "signal-io-bounded")
)
private const val CONCERNING_QUEUE_THRESHOLD = 4
private val BLOCK_BLOCKLIST = setOf("HeapTaskDaemon")
private fun buildLogString(description: String, blocked: Map<Thread, Array<StackTraceElement>>): String {
val stringBuilder = StringBuilder()
stringBuilder.append(description).append("\n")
for (entry in blocked) {
stringBuilder.append("-- [${entry.key.id}] ${entry.key.name} | ${entry.key.state}\n")
val stackTrace: Array<StackTraceElement> = entry.value
for (element in stackTrace) {
stringBuilder.append("$element\n")
}
stringBuilder.append("\n")
}
return stringBuilder.toString()
}
private fun isExecutorFull(executor: ExecutorService): Boolean {
return if (executor is ThreadPoolExecutor) {
executor.queue.size > CONCERNING_QUEUE_THRESHOLD
} else {
false
}
}
}
}

View File

@@ -0,0 +1,87 @@
package org.signal.core.util.concurrent;
import androidx.annotation.NonNull;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.concurrent.Executor;
import java.util.stream.Collectors;
/**
* A serial executor that will order pending tasks by a specified priority, and will only keep a single task of a given priority, preferring the latest.
*
* So imagine a world where the following tasks were all enqueued (meaning they're all waiting to be executed):
*
* execute(0, runnableA);
* execute(3, runnableC1);
* execute(3, runnableC2);
* execute(2, runnableB);
*
* You'd expect the execution order to be:
* - runnableC2
* - runnableB
* - runnableA
*
* (We order by priority, and C1 was replaced by C2)
*/
public final class LatestPrioritizedSerialExecutor {
private final Queue<PriorityRunnable> tasks;
private final Executor executor;
private Runnable active;
public LatestPrioritizedSerialExecutor(@NonNull Executor executor) {
this.executor = executor;
this.tasks = new PriorityQueue<>();
}
/**
* Execute with a priority. Higher priorities are executed first.
*/
public synchronized void execute(int priority, @NonNull Runnable r) {
Iterator<PriorityRunnable> iterator = tasks.iterator();
while (iterator.hasNext()) {
if (iterator.next().getPriority() == priority) {
iterator.remove();
}
}
tasks.offer(new PriorityRunnable(priority) {
@Override
public void run() {
try {
r.run();
} finally {
scheduleNext();
}
}
});
if (active == null) {
scheduleNext();
}
}
private synchronized void scheduleNext() {
if ((active = tasks.poll()) != null) {
executor.execute(active);
}
}
private abstract static class PriorityRunnable implements Runnable, Comparable<PriorityRunnable> {
private final int priority;
public PriorityRunnable(int priority) {
this.priority = priority;
}
public int getPriority() {
return priority;
}
@Override
public final int compareTo(PriorityRunnable other) {
return other.getPriority() - this.getPriority();
}
}
}

View File

@@ -0,0 +1,23 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.concurrent
import io.reactivex.rxjava3.core.Observable
import io.reactivex.rxjava3.core.Observer
import io.reactivex.rxjava3.subjects.BehaviorSubject
/**
* An Observer that provides instant access to the latest emitted value.
* Basically a read-only version of [BehaviorSubject].
*/
class LatestValueObservable<T : Any>(private val subject: BehaviorSubject<T>) : Observable<T>() {
val value: T?
get() = subject.value
override fun subscribeActual(observer: Observer<in T>) {
subject.subscribe(observer)
}
}

View File

@@ -0,0 +1,48 @@
package org.signal.core.util.concurrent
import androidx.lifecycle.DefaultLifecycleObserver
import androidx.lifecycle.Lifecycle
import androidx.lifecycle.LifecycleOwner
import io.reactivex.rxjava3.disposables.CompositeDisposable
import io.reactivex.rxjava3.disposables.Disposable
/**
* A lifecycle-aware [Disposable] that, after being bound to a lifecycle, will automatically dispose all contained disposables at the proper time.
*/
class LifecycleDisposable : DefaultLifecycleObserver {
val disposables: CompositeDisposable = CompositeDisposable()
fun bindTo(lifecycleOwner: LifecycleOwner): LifecycleDisposable {
return bindTo(lifecycleOwner.lifecycle)
}
fun bindTo(lifecycle: Lifecycle): LifecycleDisposable {
lifecycle.addObserver(this)
return this
}
fun add(disposable: Disposable): LifecycleDisposable {
disposables.add(disposable)
return this
}
fun addAll(vararg disposable: Disposable): LifecycleDisposable {
disposables.addAll(*disposable)
return this
}
fun clear() {
disposables.clear()
}
override fun onDestroy(owner: LifecycleOwner) {
owner.lifecycle.removeObserver(this)
disposables.clear()
}
operator fun plusAssign(disposable: Disposable) {
add(disposable)
}
}
fun Disposable.addTo(lifecycleDisposable: LifecycleDisposable): Disposable = apply { lifecycleDisposable.add(this) }

View File

@@ -0,0 +1,36 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.concurrent
import java.util.concurrent.CountDownLatch
import java.util.concurrent.ExecutorService
import java.util.concurrent.Semaphore
object LimitedWorker {
/**
* Call [worker] on a thread from [executor] for each element in [input] using only up to [maxThreads] concurrently.
*
* This method will block until all work is completed. There is no guarantee that the same threads
* will be used but that only up to [maxThreads] will be actively doing work.
*/
@JvmStatic
fun <T> execute(executor: ExecutorService, maxThreads: Int, input: Collection<T>, worker: (T) -> Unit) {
val doneWorkLatch = CountDownLatch(input.size)
val semaphore = Semaphore(maxThreads)
for (work in input) {
semaphore.acquire()
executor.execute {
worker(work)
semaphore.release()
doneWorkLatch.countDown()
}
}
doneWorkLatch.await()
}
}

View File

@@ -0,0 +1,40 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.concurrent
import io.reactivex.rxjava3.core.Maybe
import io.reactivex.rxjava3.exceptions.Exceptions
import io.reactivex.rxjava3.plugins.RxJavaPlugins
/**
* Kotlin 1.8 started respecting RxJava nullability annotations but RxJava has some oddities where it breaks those rules.
* This essentially re-implements [Maybe.fromCallable] with an emitter so we don't have to do it everywhere ourselves.
*/
object MaybeCompat {
fun <T : Any> fromCallable(callable: () -> T?): Maybe<T> {
return Maybe.create { emitter ->
val result = try {
callable()
} catch (e: Throwable) {
Exceptions.throwIfFatal(e)
if (!emitter.isDisposed) {
emitter.onError(e)
} else {
RxJavaPlugins.onError(e)
}
return@create
}
if (!emitter.isDisposed) {
if (result == null) {
emitter.onComplete()
} else {
emitter.onSuccess(result)
}
}
}
}
}

View File

@@ -0,0 +1,75 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@file:JvmName("RxExtensions")
package org.signal.core.util.concurrent
import androidx.lifecycle.LifecycleOwner
import io.reactivex.rxjava3.core.Completable
import io.reactivex.rxjava3.core.Flowable
import io.reactivex.rxjava3.core.Observable
import io.reactivex.rxjava3.core.Single
import io.reactivex.rxjava3.disposables.CompositeDisposable
import io.reactivex.rxjava3.kotlin.addTo
import io.reactivex.rxjava3.kotlin.subscribeBy
import io.reactivex.rxjava3.subjects.Subject
fun <T : Any> Flowable<T>.observe(viewLifecycleOwner: LifecycleOwner, onNext: (T) -> Unit) {
val lifecycleDisposable = LifecycleDisposable()
lifecycleDisposable.bindTo(viewLifecycleOwner)
lifecycleDisposable += subscribeBy(onNext = onNext)
}
fun Completable.observe(viewLifecycleOwner: LifecycleOwner, onComplete: () -> Unit) {
val lifecycleDisposable = LifecycleDisposable()
lifecycleDisposable.bindTo(viewLifecycleOwner)
lifecycleDisposable += subscribeBy(onComplete = onComplete)
}
fun <S : Subject<T>, T : Any> Observable<T>.subscribeWithSubject(
subject: S,
disposables: CompositeDisposable
): S {
subscribeBy(
onNext = subject::onNext,
onError = subject::onError,
onComplete = subject::onComplete
).addTo(disposables)
return subject
}
fun <S : Subject<T>, T : Any> Single<T>.subscribeWithSubject(
subject: S,
disposables: CompositeDisposable
): S {
subscribeBy(
onSuccess = {
subject.onNext(it)
subject.onComplete()
},
onError = subject::onError
).addTo(disposables)
return subject
}
/**
* Skips the first item emitted from the flowable, but only if it matches the provided [predicate].
*/
fun <T : Any> Flowable<T>.skipFirstIf(predicate: (T) -> Boolean): Flowable<T> {
return this
.scan(Pair<Boolean, T?>(false, null)) { acc, item ->
val firstItemInList = !acc.first
if (firstItemInList && predicate(item)) {
true to null
} else {
true to item
}
}
.filter { it.second != null }
.map { it.second!! }
}

View File

@@ -0,0 +1,40 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.concurrent
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Dispatchers
/**
* [Dispatchers] wrapper to allow tests to inject test dispatchers.
*/
object SignalDispatchers {
private var dispatcherProvider: DispatcherProvider = DefaultDispatcherProvider
fun setDispatcherProvider(dispatcherProvider: DispatcherProvider = DefaultDispatcherProvider) {
this.dispatcherProvider = dispatcherProvider
}
val Main get() = dispatcherProvider.main
val IO get() = dispatcherProvider.io
val Default get() = dispatcherProvider.default
val Unconfined get() = dispatcherProvider.unconfined
interface DispatcherProvider {
val main: CoroutineDispatcher
val io: CoroutineDispatcher
val default: CoroutineDispatcher
val unconfined: CoroutineDispatcher
}
private object DefaultDispatcherProvider : DispatcherProvider {
override val main: CoroutineDispatcher = Dispatchers.Main
override val io: CoroutineDispatcher = Dispatchers.IO
override val default: CoroutineDispatcher = Dispatchers.Default
override val unconfined: CoroutineDispatcher = Dispatchers.Unconfined
}
}

View File

@@ -0,0 +1,105 @@
package org.signal.core.util.concurrent;
import android.os.HandlerThread;
import android.os.Process;
import androidx.annotation.NonNull;
import org.signal.core.util.ThreadUtil;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
public final class SignalExecutors {
public static final ExecutorService UNBOUNDED = Executors.newCachedThreadPool(new NumberedThreadFactory("signal-unbounded", ThreadUtil.PRIORITY_BACKGROUND_THREAD));
public static final ExecutorService BOUNDED = Executors.newFixedThreadPool(4, new NumberedThreadFactory("signal-bounded", ThreadUtil.PRIORITY_BACKGROUND_THREAD));
public static final ExecutorService SERIAL = Executors.newSingleThreadExecutor(new NumberedThreadFactory("signal-serial", ThreadUtil.PRIORITY_BACKGROUND_THREAD));
public static final ExecutorService BOUNDED_IO = newCachedBoundedExecutor("signal-io-bounded", ThreadUtil.PRIORITY_IMPORTANT_BACKGROUND_THREAD, 1, 32, 30);
private SignalExecutors() {}
public static ExecutorService newCachedSingleThreadExecutor(final String name, int priority) {
ThreadPoolExecutor executor = new ThreadPoolExecutor(1, 1, 15, TimeUnit.SECONDS, new LinkedBlockingQueue<>(), r -> new Thread(r, name) {
@Override public void run() {
Process.setThreadPriority(priority);
super.run();
}
});
executor.allowCoreThreadTimeOut(true);
return executor;
}
/**
* ThreadPoolExecutor will only create a new thread if the provided queue returns false from
* offer(). That means if you give it an unbounded queue, it'll only ever create 1 thread, no
* matter how long the queue gets.
* <p>
* But if you bound the queue and submit more runnables than there are threads, your task is
* rejected and throws an exception.
* <p>
* So we make a queue that will always return false if it's non-empty to ensure new threads get
* created. Then, if a task gets rejected, we simply add it to the queue.
*/
public static ExecutorService newCachedBoundedExecutor(final String name, int priority, int minThreads, int maxThreads, int timeoutSeconds) {
ThreadPoolExecutor threadPool = new ThreadPoolExecutor(minThreads,
maxThreads,
timeoutSeconds,
TimeUnit.SECONDS,
new LinkedBlockingQueue<>() {
@Override
public boolean offer(Runnable runnable) {
if (isEmpty()) {
return super.offer(runnable);
} else {
return false;
}
}
}, new NumberedThreadFactory(name, priority));
threadPool.setRejectedExecutionHandler((runnable, executor) -> {
try {
executor.getQueue().put(runnable);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
});
return threadPool;
}
public static HandlerThread getAndStartHandlerThread(@NonNull String name, int priority) {
HandlerThread handlerThread = new HandlerThread(name, priority);
handlerThread.start();
return handlerThread;
}
public static class NumberedThreadFactory implements ThreadFactory {
private final int priority;
private final String baseName;
private final AtomicInteger counter;
public NumberedThreadFactory(@NonNull String baseName, int priority) {
this.priority = priority;
this.baseName = baseName;
this.counter = new AtomicInteger();
}
@Override
public Thread newThread(@NonNull Runnable r) {
return new Thread(r, baseName + "-" + counter.getAndIncrement()) {
@Override
public void run() {
Process.setThreadPriority(priority);
super.run();
}
};
}
}
}

View File

@@ -0,0 +1,103 @@
package org.signal.core.util.concurrent;
import android.os.AsyncTask;
import androidx.annotation.NonNull;
import androidx.lifecycle.Lifecycle;
import androidx.lifecycle.LifecycleEventObserver;
import androidx.lifecycle.LifecycleOwner;
import org.signal.core.util.ThreadUtil;
import org.signal.core.util.concurrent.SignalExecutors;
import java.util.concurrent.Executor;
import io.reactivex.rxjava3.observers.DefaultObserver;
public class SimpleTask {
/**
* Runs a task in the background and passes the result of the computation to a task that is run
* on the main thread. Will only invoke the {@code foregroundTask} if the provided {@link Lifecycle}
* is in a valid (i.e. visible) state at that time. In this way, it is very similar to
* {@link AsyncTask}, but is safe in that you can guarantee your task won't be called when your
* view is in an invalid state.
*/
public static <E> void run(@NonNull Lifecycle lifecycle, @NonNull BackgroundTask<E> backgroundTask, @NonNull ForegroundTask<E> foregroundTask) {
if (!isValid(lifecycle)) {
return;
}
SignalExecutors.BOUNDED.execute(() -> {
final E result = backgroundTask.run();
if (isValid(lifecycle)) {
ThreadUtil.runOnMain(() -> {
if (isValid(lifecycle)) {
foregroundTask.run(result);
}
});
}
});
}
/**
* Runs a task in the background and passes the result of the computation to a task that is run
* on the main thread. Will only invoke the {@code foregroundTask} if the provided {@link Lifecycle}
* is or enters in the future a valid (i.e. visible) state. In this way, it is very similar to
* {@link AsyncTask}, but is safe in that you can guarantee your task won't be called when your
* view is in an invalid state.
*/
public static <E> void runWhenValid(@NonNull Lifecycle lifecycle, @NonNull BackgroundTask<E> backgroundTask, @NonNull ForegroundTask<E> foregroundTask) {
lifecycle.addObserver(new LifecycleEventObserver() {
@Override public void onStateChanged(@NonNull LifecycleOwner lifecycleOwner, @NonNull Lifecycle.Event event) {
if (isValid(lifecycle)) {
lifecycle.removeObserver(this);
SignalExecutors.BOUNDED.execute(() -> {
final E result = backgroundTask.run();
if (isValid(lifecycle)) {
ThreadUtil.runOnMain(() -> {
if (isValid(lifecycle)) {
foregroundTask.run(result);
}
});
}
});
}
}
});
}
/**
* Runs a task in the background and passes the result of the computation to a task that is run on
* the main thread. Essentially {@link AsyncTask}, but lambda-compatible.
*/
public static <E> void run(@NonNull BackgroundTask<E> backgroundTask, @NonNull ForegroundTask<E> foregroundTask) {
run(SignalExecutors.BOUNDED, backgroundTask, foregroundTask);
}
/**
* Runs a task on the specified {@link Executor} and passes the result of the computation to a
* task that is run on the main thread. Essentially {@link AsyncTask}, but lambda-compatible.
*/
public static <E> void run(@NonNull Executor executor, @NonNull BackgroundTask<E> backgroundTask, @NonNull ForegroundTask<E> foregroundTask) {
executor.execute(() -> {
final E result = backgroundTask.run();
ThreadUtil.runOnMain(() -> foregroundTask.run(result));
});
}
private static boolean isValid(@NonNull Lifecycle lifecycle) {
return lifecycle.getCurrentState().isAtLeast(Lifecycle.State.CREATED);
}
public interface BackgroundTask<E> {
E run();
}
public interface ForegroundTask<E> {
void run(E result);
}
}

View File

@@ -0,0 +1,60 @@
package org.signal.core.util.logging
import android.annotation.SuppressLint
import java.util.concurrent.CountDownLatch
import java.util.concurrent.Executor
import java.util.concurrent.Executors
@SuppressLint("LogNotSignal")
object AndroidLogger : Log.Logger() {
private val serialExecutor: Executor = Executors.newSingleThreadExecutor { Thread(it, "signal-logcat") }
override fun v(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) {
serialExecutor.execute {
android.util.Log.v(tag, message.scrub(), t)
}
}
override fun d(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) {
serialExecutor.execute {
android.util.Log.d(tag, message.scrub(), t)
}
}
override fun i(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) {
serialExecutor.execute {
android.util.Log.i(tag, message.scrub(), t)
}
}
override fun w(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) {
serialExecutor.execute {
android.util.Log.w(tag, message.scrub(), t)
}
}
override fun e(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) {
serialExecutor.execute {
android.util.Log.e(tag, message.scrub(), t)
}
}
override fun flush() {
val latch = CountDownLatch(1)
serialExecutor.execute {
latch.countDown()
}
try {
latch.await()
} catch (e: InterruptedException) {
android.util.Log.w("AndroidLogger", "Interrupted while waiting for flush()", e)
}
}
private fun String?.scrub(): String? {
return this?.let { Scrubber.scrub(it).toString() }
}
}

View File

@@ -0,0 +1,106 @@
package org.signal.core.util.money;
import androidx.annotation.NonNull;
import java.math.BigDecimal;
import java.text.NumberFormat;
import java.util.Currency;
import java.util.HashSet;
import java.util.Locale;
import java.util.Objects;
import java.util.Set;
public class FiatMoney {
private static final Set<String> SPECIAL_CASE_MULTIPLICANDS = new HashSet<>() {{
add("UGX");
add("ISK");
}};
private final BigDecimal amount;
private final Currency currency;
private final long timestamp;
public FiatMoney(@NonNull BigDecimal amount, @NonNull Currency currency) {
this(amount, currency, 0);
}
public FiatMoney(@NonNull BigDecimal amount, @NonNull Currency currency, long timestamp) {
this.amount = amount;
this.currency = currency;
this.timestamp = timestamp;
}
public @NonNull BigDecimal getAmount() {
return amount;
}
public @NonNull Currency getCurrency() {
return currency;
}
public long getTimestamp() {
return timestamp;
}
/**
* @return amount, rounded to the default fractional amount.
*/
public @NonNull String getDefaultPrecisionString() {
return getDefaultPrecisionString(Locale.getDefault());
}
/**
* @return amount, rounded to the default fractional amount.
*/
public @NonNull String getDefaultPrecisionString(@NonNull Locale locale) {
NumberFormat formatter = NumberFormat.getInstance(locale);
formatter.setMinimumFractionDigits(currency.getDefaultFractionDigits());
formatter.setGroupingUsed(false);
return formatter.format(amount);
}
/**
* Note: This special cases SPECIAL_CASE_MULTIPLICANDS members to act as two decimal.
*
* @return amount, in smallest possible units (cents, yen, etc.)
*/
public @NonNull String getMinimumUnitPrecisionString() {
NumberFormat formatter = NumberFormat.getInstance(Locale.US);
formatter.setMaximumFractionDigits(0);
formatter.setGroupingUsed(false);
String currencyCode = currency.getCurrencyCode();
BigDecimal multiplicand = BigDecimal.TEN.pow(SPECIAL_CASE_MULTIPLICANDS.contains(currencyCode) ? 2 : currency.getDefaultFractionDigits());
return formatter.format(amount.multiply(multiplicand));
}
/**
* Transforms the given currency / amount pair from a signal network amount to a FiatMoney, accounting for the special
* cased multiplicands for ISK and UGX
*/
public static @NonNull FiatMoney fromSignalNetworkAmount(@NonNull BigDecimal amount, @NonNull Currency currency) {
String currencyCode = currency.getCurrencyCode();
int shift = SPECIAL_CASE_MULTIPLICANDS.contains(currencyCode) ? 2: currency.getDefaultFractionDigits();
BigDecimal shiftedAmount = amount.movePointLeft(shift);
return new FiatMoney(shiftedAmount, currency);
}
public static boolean equals(FiatMoney left, FiatMoney right) {
return Objects.equals(left.amount, right.amount) &&
Objects.equals(left.currency, right.currency) &&
Objects.equals(left.timestamp, right.timestamp);
}
@Override
public String toString() {
return "FiatMoney{" +
"amount=" + amount +
", currency=" + currency +
", timestamp=" + timestamp +
'}';
}
}

View File

@@ -0,0 +1,24 @@
package org.signal.core.util.money
import java.util.Currency
/**
* Utility methods for java.util.Currency
*
* This is prefixed with "Platform" as there are several different Currency classes
* available in the app, and this utility class is specifically for dealing with
* java.util.Currency
*/
object PlatformCurrencyUtil {
val USD: Currency = Currency.getInstance("USD")
/**
* Note: Adding this as an extension method of Currency causes some confusion in
* AndroidStudio due to a separate Currency class from the AndroidSDK having
* an extension method of the same signature.
*/
fun getAvailableCurrencyCodes(): Set<String> {
return Currency.getAvailableCurrencies().map { it.currencyCode }.toSet()
}
}

View File

@@ -0,0 +1,236 @@
package org.signal.core.util.tracing;
import android.os.SystemClock;
import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import okio.ByteString;
/**
* A class to create Perfetto-compatible traces. Currently keeps the entire trace in memory to
* avoid weirdness with synchronizing to disk.
* <p>
* Some general info on how the Perfetto format works:
* - The file format is just a Trace proto (see Trace.proto)
* - The Trace proto is just a series of TracePackets
* - TracePackets can describe:
* - Threads
* - Start of a method
* - End of a method
* - (And a bunch of other stuff that's not relevant to use at this point)
* <p>
* We keep a circular buffer of TracePackets for method calls, and we keep a separate list of
* TracePackets for threads so we don't lose any of those.
* <p>
* Serializing is just a matter of throwing all the TracePackets we have into a proto.
* <p>
* Note: This class aims to be largely-thread-safe, but prioritizes speed and memory efficiency
* above all else. These methods are going to be called very quickly from every thread imaginable,
* and we want to create as little overhead as possible. The idea being that it's ok if we don't,
* for example, keep a perfect circular buffer size if it allows us to reduce overhead. The only
* cost of screwing up would be dropping a trace packet or something, which, while sad, won't affect
* how the app functions
*/
public final class Tracer {
public static final class TrackId {
public static final long DB_LOCK = -8675309;
private static final String DB_LOCK_NAME = "Database Lock";
}
private static final Tracer INSTANCE = new Tracer();
private static final int TRUSTED_SEQUENCE_ID = 1;
private static final byte[] SYNCHRONIZATION_MARKER = toByteArray(UUID.fromString("82477a76-b28d-42ba-81dc-33326d57a079"));
private static final long SYNCHRONIZATION_INTERVAL = TimeUnit.SECONDS.toNanos(3);
private final Clock clock;
private final Map<Long, TracePacket> threadPackets;
private final Queue<TracePacket> eventPackets;
private final AtomicInteger eventCount;
private long lastSyncTime;
private long maxBufferSize;
private Tracer() {
this.clock = SystemClock::elapsedRealtimeNanos;
this.threadPackets = new ConcurrentHashMap<>();
this.eventPackets = new ConcurrentLinkedQueue<>();
this.eventCount = new AtomicInteger(0);
this.maxBufferSize = 3_500;
}
public static @NonNull Tracer getInstance() {
return INSTANCE;
}
public void setMaxBufferSize(long maxBufferSize) {
this.maxBufferSize = maxBufferSize;
}
public void start(@NonNull String methodName) {
start(methodName, Thread.currentThread().getId(), null);
}
public void start(@NonNull String methodName, long trackId) {
start(methodName, trackId, null);
}
public void start(@NonNull String methodName, @NonNull String key, @Nullable String value) {
start(methodName, Thread.currentThread().getId(), key, value);
}
public void start(@NonNull String methodName, long trackId, @NonNull String key, @Nullable String value) {
start(methodName, trackId, Collections.singletonMap(key, value));
}
public void start(@NonNull String methodName, @Nullable Map<String, String> values) {
start(methodName, Thread.currentThread().getId(), values);
}
public void start(@NonNull String methodName, long trackId, @Nullable Map<String, String> values) {
long time = clock.getTimeNanos();
if (time - lastSyncTime > SYNCHRONIZATION_INTERVAL) {
addPacket(forSynchronization(time));
lastSyncTime = time;
}
if (!threadPackets.containsKey(trackId)) {
threadPackets.put(trackId, forTrackId(trackId));
}
addPacket(forMethodStart(methodName, time, trackId, values));
}
public void end(@NonNull String methodName) {
addPacket(forMethodEnd(methodName, clock.getTimeNanos(), Thread.currentThread().getId()));
}
public void end(@NonNull String methodName, long trackId) {
addPacket(forMethodEnd(methodName, clock.getTimeNanos(), trackId));
}
public @NonNull byte[] serialize() {
List<TracePacket> packets = new ArrayList<>();
packets.addAll(threadPackets.values());
packets.addAll(eventPackets);
packets.add(forSynchronization(clock.getTimeNanos()));
return new Trace.Builder().packet(packets).build().encode();
}
/**
* Attempts to add a packet to our list while keeping the size of our circular buffer in-check.
* The tracking of the event count is not perfectly thread-safe, but doing it in a thread-safe
* way would likely involve adding a lock, which we really don't want to do, since it'll add
* unnecessary overhead.
* <p>
* Note that we keep track of the event count separately because
* {@link ConcurrentLinkedQueue#size()} is NOT a constant-time operation.
*/
private void addPacket(@NonNull TracePacket packet) {
eventPackets.add(packet);
int size = eventCount.incrementAndGet();
for (int i = size; i > maxBufferSize; i--) {
eventPackets.poll();
eventCount.decrementAndGet();
}
}
private TracePacket forTrackId(long id) {
if (id == TrackId.DB_LOCK) {
return forTrack(id, TrackId.DB_LOCK_NAME);
} else {
Thread currentThread = Thread.currentThread();
return forTrack(currentThread.getId(), currentThread.getName());
}
}
private static TracePacket forTrack(long id, String name) {
return new TracePacket.Builder()
.trusted_packet_sequence_id(TRUSTED_SEQUENCE_ID)
.track_descriptor(new TrackDescriptor.Builder()
.uuid(id)
.name(name).build())
.build();
}
private static TracePacket forMethodStart(@NonNull String name, long time, long threadId, @Nullable Map<String, String> values) {
TrackEvent.Builder event = new TrackEvent.Builder()
.track_uuid(threadId)
.name(name)
.type(TrackEvent.Type.TYPE_SLICE_BEGIN);
List<DebugAnnotation> debugAnnotations = new LinkedList<>();
if (values != null) {
for (Map.Entry<String, String> entry : values.entrySet()) {
debugAnnotations.add(debugAnnotation(entry.getKey(), entry.getValue()));
}
}
event.debug_annotations(debugAnnotations);
return new TracePacket.Builder()
.trusted_packet_sequence_id(TRUSTED_SEQUENCE_ID)
.timestamp(time)
.track_event(event.build())
.build();
}
private static DebugAnnotation debugAnnotation(@NonNull String key, @Nullable String value) {
return new DebugAnnotation.Builder()
.name(key)
.string_value(value != null ? value : "")
.build();
}
private static TracePacket forMethodEnd(@NonNull String name, long time, long threadId) {
return new TracePacket.Builder()
.trusted_packet_sequence_id(TRUSTED_SEQUENCE_ID)
.timestamp(time)
.track_event(new TrackEvent.Builder()
.track_uuid(threadId)
.name(name)
.type(TrackEvent.Type.TYPE_SLICE_END)
.build())
.build();
}
private static TracePacket forSynchronization(long time) {
return new TracePacket.Builder()
.trusted_packet_sequence_id(TRUSTED_SEQUENCE_ID)
.timestamp(time)
.synchronization_marker(ByteString.of(SYNCHRONIZATION_MARKER))
.build();
}
public static byte[] toByteArray(UUID uuid) {
ByteBuffer buffer = ByteBuffer.wrap(new byte[16]);
buffer.putLong(uuid.getMostSignificantBits());
buffer.putLong(uuid.getLeastSignificantBits());
return buffer.array();
}
private interface Clock {
long getTimeNanos();
}
}

View File

@@ -0,0 +1,151 @@
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
syntax = "proto2";
package signal;
option java_package = "org.signal.core.util.tracing";
option java_outer_classname = "TraceProtos";
/*
* Minimal interface needed to work with Perfetto.
*
* https://cs.android.com/android/platform/superproject/+/master:external/perfetto/protos/perfetto/trace/trace.proto
*/
message Trace {
repeated TracePacket packet = 1;
}
message TracePacket {
optional uint64 timestamp = 8;
optional uint32 timestamp_clock_id = 58;
oneof data {
TrackEvent track_event = 11;
TrackDescriptor track_descriptor = 60;
bytes synchronization_marker = 36;
}
oneof optional_trusted_packet_sequence_id {
uint32 trusted_packet_sequence_id = 10;
}
}
message TrackEvent {
repeated uint64 category_iids = 3;
repeated string categories = 22;
repeated DebugAnnotation debug_annotations = 4;
oneof name_field {
uint64 name_iid = 10;
string name = 23;
}
enum Type {
TYPE_UNSPECIFIED = 0;
TYPE_SLICE_BEGIN = 1;
TYPE_SLICE_END = 2;
TYPE_INSTANT = 3;
TYPE_COUNTER = 4;
}
optional Type type = 9;
optional uint64 track_uuid = 11;
optional int64 counter_value = 30;
oneof timestamp {
int64 timestamp_delta_us = 1;
int64 timestamp_absolute_us = 16;
}
oneof thread_time {
int64 thread_time_delta_us = 2;
int64 thread_time_absolute_us = 17;
}
}
message TrackDescriptor {
optional uint64 uuid = 1;
optional uint64 parent_uuid = 5;
optional string name = 2;
optional ThreadDescriptor thread = 4;
optional CounterDescriptor counter = 8;
}
message ThreadDescriptor {
optional int32 pid = 1;
optional int32 tid = 2;
optional string thread_name = 5;
}
message CounterDescriptor {
enum BuiltinCounterType {
COUNTER_UNSPECIFIED = 0;
COUNTER_THREAD_TIME_NS = 1;
COUNTER_THREAD_INSTRUCTION_COUNT = 2;
}
enum Unit {
UNIT_UNSPECIFIED = 0;
UNIT_TIME_NS = 1;
UNIT_COUNT = 2;
UNIT_SIZE_BYTES = 3;
}
optional BuiltinCounterType type = 1;
repeated string categories = 2;
optional Unit unit = 3;
optional int64 unit_multiplier = 4;
optional bool is_incremental = 5;
}
message DebugAnnotation {
message NestedValue {
enum NestedType {
UNSPECIFIED = 0;
DICT = 1;
ARRAY = 2;
}
optional NestedType nested_type = 1;
repeated string dict_keys = 2;
repeated NestedValue dict_values = 3;
repeated NestedValue array_values = 4;
optional int64 int_value = 5;
optional double double_value = 6;
optional bool bool_value = 7;
optional string string_value = 8;
}
oneof name_field {
uint64 name_iid = 1;
string name = 10;
}
oneof value {
bool bool_value = 2;
uint64 uint_value = 3;
int64 int_value = 4;
double double_value = 5;
string string_value = 6;
uint64 pointer_value = 7;
NestedValue nested_value = 8;
}
}

View File

@@ -0,0 +1,9 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="24dp"
android:height="24dp"
android:viewportWidth="24"
android:viewportHeight="24">
<path
android:fillColor="#FFFFFFFF"
android:pathData="M22,4.5L16.35,4.5a4.45,4.45 0,0 0,-8.7 0L2,4.5L2,6L3.5,6L4.86,20A2.25,2.25 0,0 0,7.1 22h9.8a2.25,2.25 0,0 0,2.24 -2L20.5,6L22,6ZM12,2.5a3,3 0,0 1,2.82 2L9.18,4.5A3,3 0,0 1,12 2.5ZM17.65,19.83a0.76,0.76 0,0 1,-0.75 0.67L7.1,20.5a0.76,0.76 0,0 1,-0.75 -0.67L5,6L19,6ZM11.25,18L11.25,8h1.5L12.75,18ZM14.5,18L15,8h1.5L16,18ZM8,18 L7.5,8L9,8l0.5,10Z"/>
</vector>