Replace Firebase ML Vision with built-in face detection.

This commit is contained in:
Alan Evans
2021-01-27 20:00:07 -04:00
committed by Greyson Parrelli
parent 1b448c2bdf
commit 6a45858b4a
6 changed files with 120 additions and 127 deletions

View File

@@ -0,0 +1,98 @@
package org.thoughtcrime.securesms.scribbles;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.PointF;
import android.graphics.RectF;
import androidx.annotation.NonNull;
import com.annimon.stream.Stream;
import org.signal.core.util.logging.Log;
import java.util.List;
import java.util.Locale;
/**
* Detects faces with the built in Android face detection.
*/
/**
 * Detects faces using the platform's built-in {@link android.media.FaceDetector}.
 */
final class AndroidFaceDetector implements FaceDetector {

  private static final String TAG = Log.tag(AndroidFaceDetector.class);

  /** Upper bound on the number of faces a single pass will report. */
  private static final int MAX_FACES = 20;

  /**
   * Runs face detection over the supplied bitmap.
   * <p>
   * {@link android.media.FaceDetector} requires an RGB_565 bitmap with an even width, so the
   * source is copied into a conforming bitmap first when necessary.
   */
  @Override
  public List<Face> detect(@NonNull Bitmap source) {
    long started = System.currentTimeMillis();

    Log.d(TAG, String.format(Locale.US, "Bitmap format is %dx%d %s", source.getWidth(), source.getHeight(), source.getConfig()));

    boolean needsConversion = source.getConfig() != Bitmap.Config.RGB_565 || source.getWidth() % 2 != 0;
    Bitmap  detectable;

    if (needsConversion) {
      Log.d(TAG, "Changing colour format to 565, with even width");
      // Mask off the lowest bit of the width to force it even, then paint the source in.
      detectable = Bitmap.createBitmap(source.getWidth() & ~0x1, source.getHeight(), Bitmap.Config.RGB_565);
      new Canvas(detectable).drawBitmap(source, 0, 0, null);
    } else {
      detectable = source;
    }

    try {
      android.media.FaceDetector        detector = new android.media.FaceDetector(detectable.getWidth(), detectable.getHeight(), MAX_FACES);
      android.media.FaceDetector.Face[] found    = new android.media.FaceDetector.Face[MAX_FACES];

      int count = detector.findFaces(detectable, found);

      Log.d(TAG, String.format(Locale.US, "Found %d faces", count));

      // Only the first {@code count} slots of the array are populated.
      return Stream.of(found)
                   .limit(count)
                   .map(AndroidFaceDetector::faceToFace)
                   .toList();
    } finally {
      // Only recycle the bitmap we allocated ourselves, never the caller's.
      if (needsConversion) {
        detectable.recycle();
      }
      Log.d(TAG, "Finished in " + (System.currentTimeMillis() - started) + " ms");
    }
  }

  /**
   * Adapts a platform face into our {@link Face}, deriving a square bounding box from the
   * eye midpoint and eye distance, shifted slightly downward to better cover the face.
   */
  private static Face faceToFace(@NonNull android.media.FaceDetector.Face face) {
    PointF mid = new PointF();
    face.getMidPoint(mid);

    float eyes     = face.eyesDistance();
    float halfSide = eyes * 1.4f;
    float drop     = eyes * 0.4f;

    RectF box = new RectF(mid.x - halfSide,
                          mid.y + drop - halfSide,
                          mid.x + halfSide,
                          mid.y + drop + halfSide);

    return new DefaultFace(box, face.confidence());
  }

  /** Simple immutable {@link Face} carrying a bounding box and a confidence score. */
  private static class DefaultFace implements Face {

    private final RectF bounds;
    private final float confidence;

    public DefaultFace(@NonNull RectF bounds, float confidence) {
      this.bounds     = bounds;
      this.confidence = confidence;
    }

    @Override
    public RectF getBounds() {
      return bounds;
    }

    @Override
    public Class<? extends FaceDetector> getDetectorClass() {
      return AndroidFaceDetector.class;
    }

    @Override
    public float getConfidence() {
      return confidence;
    }
  }
}

View File

@@ -3,8 +3,18 @@ package org.thoughtcrime.securesms.scribbles;
import android.graphics.Bitmap;
import android.graphics.RectF;
import androidx.annotation.NonNull;
import java.util.List;
/**
 * Strategy interface for locating faces within a bitmap.
 */
interface FaceDetector {

  /**
   * Finds faces in the supplied bitmap.
   *
   * @param bitmap Image to scan.
   * @return Possibly-empty list of detected faces; never null.
   */
  List<Face> detect(@NonNull Bitmap bitmap);

  /** A single detected face. */
  interface Face {

    /** Bounding box of the face, in the coordinate space of the scanned bitmap. */
    RectF getBounds();

    /** The {@link FaceDetector} implementation that produced this face. */
    Class<? extends FaceDetector> getDetectorClass();

    /** Detector-reported confidence for this detection. */
    float getConfidence();
  }
}

View File

@@ -1,79 +0,0 @@
package org.thoughtcrime.securesms.scribbles;
import android.graphics.Bitmap;
import android.graphics.RectF;
import android.os.Build;
import com.annimon.stream.Stream;
import com.google.firebase.ml.vision.FirebaseVision;
import com.google.firebase.ml.vision.common.FirebaseVisionImage;
import com.google.firebase.ml.vision.face.FirebaseVisionFace;
import com.google.firebase.ml.vision.face.FirebaseVisionFaceDetector;
import com.google.firebase.ml.vision.face.FirebaseVisionFaceDetectorOptions;
import org.signal.core.util.logging.Log;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
/**
 * Detects faces via the Firebase ML Vision on-device face detector.
 * <p>
 * The Firebase API is asynchronous; {@link #detect} blocks the calling thread on a latch
 * until the callback fires (or a 15 second timeout elapses), so it must not be called on
 * the main thread. NOTE(review): being replaced by {@code AndroidFaceDetector} in this commit.
 */
class FirebaseFaceDetector implements FaceDetector {
private static final String TAG = Log.tag(FirebaseFaceDetector.class);
// Pixel-count threshold above which ACCURATE mode is considered too slow.
private static final long MAX_SIZE = 1000 * 1000;
@Override
public List<RectF> detect(Bitmap source) {
long startTime = System.currentTimeMillis();
int performanceMode = getPerformanceMode(source);
Log.d(TAG, "Using performance mode " + performanceMode + " (API " + Build.VERSION.SDK_INT + ", " + source.getWidth() + "x" + source.getHeight() + ")");
// Bounding boxes only: contours, landmarks, and classification are all disabled.
FirebaseVisionFaceDetectorOptions options = new FirebaseVisionFaceDetectorOptions.Builder()
.setPerformanceMode(performanceMode)
.setMinFaceSize(0.05f)
.setContourMode(FirebaseVisionFaceDetectorOptions.NO_CONTOURS)
.setLandmarkMode(FirebaseVisionFaceDetectorOptions.NO_LANDMARKS)
.setClassificationMode(FirebaseVisionFaceDetectorOptions.NO_CLASSIFICATIONS)
.build();
FirebaseVisionImage image = FirebaseVisionImage.fromBitmap(source);
List<RectF> output = new ArrayList<>();
// try-with-resources closes the detector; the latch bridges the async callback to this thread.
try (FirebaseVisionFaceDetector detector = FirebaseVision.getInstance().getVisionFaceDetector(options)) {
CountDownLatch latch = new CountDownLatch(1);
detector.detectInImage(image)
.addOnSuccessListener(firebaseVisionFaces -> {
// Convert the integer Rect bounding boxes to RectF for the caller.
output.addAll(Stream.of(firebaseVisionFaces)
.map(FirebaseVisionFace::getBoundingBox)
.map(r -> new RectF(r.left, r.top, r.right, r.bottom))
.toList());
latch.countDown();
})
// Failures are treated as "no faces found" rather than surfaced to the caller.
.addOnFailureListener(e -> latch.countDown());
// Timeout result is deliberately ignored: on timeout we simply return whatever was collected.
latch.await(15, TimeUnit.SECONDS);
} catch (IOException e) {
Log.w(TAG, "Failed to close!", e);
} catch (InterruptedException e) {
Log.w(TAG, e);
}
Log.d(TAG, "Finished in " + (System.currentTimeMillis() - startTime) + " ms");
return output;
}
/**
 * Chooses ACCURATE mode only on API 28+ and only for images under {@link #MAX_SIZE} pixels;
 * otherwise falls back to FAST.
 */
private static int getPerformanceMode(Bitmap source) {
if (Build.VERSION.SDK_INT < 28) {
return FirebaseVisionFaceDetectorOptions.FAST;
}
return source.getWidth() * source.getHeight() < MAX_SIZE ? FirebaseVisionFaceDetectorOptions.ACCURATE
: FirebaseVisionFaceDetectorOptions.FAST;
}
}

View File

@@ -375,7 +375,7 @@ public final class ImageEditorFragment extends Fragment implements ImageEditorHu
if (mainImage.getRenderer() != null) {
Bitmap bitmap = ((UriGlideRenderer) mainImage.getRenderer()).getBitmap();
if (bitmap != null) {
FaceDetector detector = new FirebaseFaceDetector();
FaceDetector detector = new AndroidFaceDetector();
Point size = model.getOutputSizeMaxWidth(1000);
Bitmap render = model.render(ApplicationDependencies.getApplication(), size);
@@ -486,7 +486,7 @@ public final class ImageEditorFragment extends Fragment implements ImageEditorHu
}
private void renderFaceBlurs(@NonNull FaceDetectionResult result) {
List<RectF> faces = result.rects;
List<FaceDetector.Face> faces = result.faces;
if (faces.isEmpty()) {
cachedFaceDetection = null;
@@ -497,12 +497,12 @@ public final class ImageEditorFragment extends Fragment implements ImageEditorHu
Matrix faceMatrix = new Matrix();
for (RectF face : faces) {
FaceBlurRenderer faceBlurRenderer = new FaceBlurRenderer();
for (FaceDetector.Face face : faces) {
Renderer faceBlurRenderer = new FaceBlurRenderer();
EditorElement element = new EditorElement(faceBlurRenderer, EditorModel.Z_MASK);
Matrix localMatrix = element.getLocalMatrix();
faceMatrix.setRectToRect(Bounds.FULL_BOUNDS, face, Matrix.ScaleToFit.FILL);
faceMatrix.setRectToRect(Bounds.FULL_BOUNDS, face.getBounds(), Matrix.ScaleToFit.FILL);
localMatrix.set(result.position);
localMatrix.preConcat(faceMatrix);
@@ -574,11 +574,11 @@ public final class ImageEditorFragment extends Fragment implements ImageEditorHu
}
private static class FaceDetectionResult {
private final List<RectF> rects;
private final Matrix position;
private final List<FaceDetector.Face> faces;
private final Matrix position;
private FaceDetectionResult(@NonNull List<RectF> rects, @NonNull Point imageSize, @NonNull Matrix position) {
this.rects = rects;
private FaceDetectionResult(@NonNull List<FaceDetector.Face> faces, @NonNull Point imageSize, @NonNull Matrix position) {
this.faces = faces;
this.position = new Matrix(position);
Matrix imageProjectionMatrix = new Matrix();