[mob][photos] Use SIMD in sorting suggestions too

This commit is contained in:
laurenspriem 2024-04-24 16:19:10 +05:30
parent e829f7b62f
commit 3806ee3232

View File

@@ -6,15 +6,12 @@ import "package:logging/logging.dart";
import "package:ml_linalg/linalg.dart"; import "package:ml_linalg/linalg.dart";
import "package:photos/core/event_bus.dart"; import "package:photos/core/event_bus.dart";
import "package:photos/db/files_db.dart"; import "package:photos/db/files_db.dart";
// import "package:photos/events/files_updated_event.dart";
// import "package:photos/events/local_photos_updated_event.dart";
import "package:photos/events/people_changed_event.dart"; import "package:photos/events/people_changed_event.dart";
import "package:photos/extensions/stop_watch.dart"; import "package:photos/extensions/stop_watch.dart";
import "package:photos/face/db.dart"; import "package:photos/face/db.dart";
import "package:photos/face/model/person.dart"; import "package:photos/face/model/person.dart";
import "package:photos/generated/protos/ente/common/vector.pb.dart"; import "package:photos/generated/protos/ente/common/vector.pb.dart";
import "package:photos/models/file/file.dart"; import "package:photos/models/file/file.dart";
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart"; import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart"; import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart"; import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
@@ -555,17 +552,22 @@ class ClusterFeedbackService {
// Take the embeddings from the person's clusters in one big list and sample from it // Take the embeddings from the person's clusters in one big list and sample from it
final List<Uint8List> personEmbeddingsProto = []; final List<Uint8List> personEmbeddingsProto = [];
for (final clusterID in personClusters) { for (final clusterID in personClusters) {
final Iterable<Uint8List> embedings = final Iterable<Uint8List> embeddings =
await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID); await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID);
personEmbeddingsProto.addAll(embedings); personEmbeddingsProto.addAll(embeddings);
} }
final List<Uint8List> sampledEmbeddingsProto = final List<Uint8List> sampledEmbeddingsProto =
_randomSampleWithoutReplacement( _randomSampleWithoutReplacement(
personEmbeddingsProto, personEmbeddingsProto,
sampleSize, sampleSize,
); );
final List<List<double>> sampledEmbeddings = sampledEmbeddingsProto final List<Vector> sampledEmbeddings = sampledEmbeddingsProto
.map((embedding) => EVector.fromBuffer(embedding).values) .map(
(embedding) => Vector.fromList(
EVector.fromBuffer(embedding).values,
dtype: DType.float32,
),
)
.toList(growable: false); .toList(growable: false);
// Find the actual closest clusters for the person using median // Find the actual closest clusters for the person using median
@@ -581,16 +583,20 @@ class ClusterFeedbackService {
otherEmbeddingsProto, otherEmbeddingsProto,
sampleSize, sampleSize,
); );
final List<List<double>> sampledOtherEmbeddings = final List<Vector> sampledOtherEmbeddings = sampledOtherEmbeddingsProto
sampledOtherEmbeddingsProto .map(
.map((embedding) => EVector.fromBuffer(embedding).values) (embedding) => Vector.fromList(
.toList(growable: false); EVector.fromBuffer(embedding).values,
dtype: DType.float32,
),
)
.toList(growable: false);
// Calculate distances and find the median // Calculate distances and find the median
final List<double> distances = []; final List<double> distances = [];
for (final otherEmbedding in sampledOtherEmbeddings) { for (final otherEmbedding in sampledOtherEmbeddings) {
for (final embedding in sampledEmbeddings) { for (final embedding in sampledEmbeddings) {
distances.add(cosineDistForNormVectors(embedding, otherEmbedding)); distances.add(1 - embedding.dot(otherEmbedding));
} }
} }
distances.sort(); distances.sort();
@@ -671,8 +677,9 @@ class ClusterFeedbackService {
if (clusterToSummary[id]?.$2 == allClusterIdsToCountMap[id]) { if (clusterToSummary[id]?.$2 == allClusterIdsToCountMap[id]) {
allClusterIds.remove(id); allClusterIds.remove(id);
clusterAvg[id] = Vector.fromList( clusterAvg[id] = Vector.fromList(
EVector.fromBuffer(clusterToSummary[id]!.$1).values, EVector.fromBuffer(clusterToSummary[id]!.$1).values,
dtype: DType.float32,); dtype: DType.float32,
);
alreadyUpdatedClustersCnt++; alreadyUpdatedClustersCnt++;
} }
if (allClusterIdsToCountMap[id]! < minClusterSize) { if (allClusterIdsToCountMap[id]! < minClusterSize) {
@@ -738,10 +745,12 @@ class ClusterFeedbackService {
for (final clusterID in clusterEmbeddings.keys) { for (final clusterID in clusterEmbeddings.keys) {
final Iterable<Uint8List> embeddings = clusterEmbeddings[clusterID]!; final Iterable<Uint8List> embeddings = clusterEmbeddings[clusterID]!;
final Iterable<Vector> vectors = embeddings.map((e) => Vector.fromList( final Iterable<Vector> vectors = embeddings.map(
EVector.fromBuffer(e).values, (e) => Vector.fromList(
dtype: DType.float32, EVector.fromBuffer(e).values,
),); dtype: DType.float32,
),
);
final avg = vectors.reduce((a, b) => a + b) / vectors.length; final avg = vectors.reduce((a, b) => a + b) / vectors.length;
final avgEmbeddingBuffer = EVector(values: avg).writeToBuffer(); final avgEmbeddingBuffer = EVector(values: avg).writeToBuffer();
updatesForClusterSummary[clusterID] = updatesForClusterSummary[clusterID] =
@@ -908,16 +917,16 @@ class ClusterFeedbackService {
final personEmbeddingsCount = personClusters final personEmbeddingsCount = personClusters
.map((e) => personClusterToSummary[e]!.$2) .map((e) => personClusterToSummary[e]!.$2)
.reduce((a, b) => a + b); .reduce((a, b) => a + b);
final List<double> personAvg = List.filled(192, 0); Vector personAvg = Vector.filled(192, 0);
for (final personClusterID in personClusters) { for (final personClusterID in personClusters) {
final personClusterBlob = personClusterToSummary[personClusterID]!.$1; final personClusterBlob = personClusterToSummary[personClusterID]!.$1;
final personClusterAvg = EVector.fromBuffer(personClusterBlob).values; final personClusterAvg = Vector.fromList(
EVector.fromBuffer(personClusterBlob).values,
dtype: DType.float32,
);
final clusterWeight = final clusterWeight =
personClusterToSummary[personClusterID]!.$2 / personEmbeddingsCount; personClusterToSummary[personClusterID]!.$2 / personEmbeddingsCount;
for (int i = 0; i < personClusterAvg.length; i++) { personAvg += personClusterAvg * clusterWeight;
personAvg[i] += personClusterAvg[i] *
clusterWeight; // Weighted sum of the cluster averages
}
} }
w?.log('calculated person avg'); w?.log('calculated person avg');
@@ -933,16 +942,22 @@ class ClusterFeedbackService {
final faceIdToEmbeddingMap = await faceMlDb.getFaceEmbeddingMapForFaces( final faceIdToEmbeddingMap = await faceMlDb.getFaceEmbeddingMapForFaces(
faceIDs, faceIDs,
); );
final faceIdToVectorMap = faceIdToEmbeddingMap.map(
(key, value) => MapEntry(
key,
Vector.fromList(
EVector.fromBuffer(value).values,
dtype: DType.float32,
),
),
);
w?.log( w?.log(
'got ${faceIdToEmbeddingMap.values.length} embeddings for ${suggestion.filesInCluster.length} files for cluster $clusterID', 'got ${faceIdToEmbeddingMap.values.length} embeddings for ${suggestion.filesInCluster.length} files for cluster $clusterID',
); );
final fileIdToDistanceMap = {}; final fileIdToDistanceMap = {};
for (final entry in faceIdToEmbeddingMap.entries) { for (final entry in faceIdToVectorMap.entries) {
fileIdToDistanceMap[getFileIdFromFaceId(entry.key)] = fileIdToDistanceMap[getFileIdFromFaceId(entry.key)] =
cosineDistForNormVectors( 1 - personAvg.dot(entry.value);
personAvg,
EVector.fromBuffer(entry.value).values,
);
} }
w?.log('calculated distances for cluster $clusterID'); w?.log('calculated distances for cluster $clusterID');
suggestion.filesInCluster.sort((b, a) { suggestion.filesInCluster.sort((b, a) {