diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index d11afa180d..0e6d1bc95a 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -6,15 +6,12 @@ import "package:logging/logging.dart"; import "package:ml_linalg/linalg.dart"; import "package:photos/core/event_bus.dart"; import "package:photos/db/files_db.dart"; -// import "package:photos/events/files_updated_event.dart"; -// import "package:photos/events/local_photos_updated_event.dart"; import "package:photos/events/people_changed_event.dart"; import "package:photos/extensions/stop_watch.dart"; import "package:photos/face/db.dart"; import "package:photos/face/model/person.dart"; import "package:photos/generated/protos/ente/common/vector.pb.dart"; import "package:photos/models/file/file.dart"; -import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart'; import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart"; import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart"; import "package:photos/services/machine_learning/face_ml/face_ml_result.dart"; @@ -555,17 +552,22 @@ class ClusterFeedbackService { // Take the embeddings from the person's clusters in one big list and sample from it final List personEmbeddingsProto = []; for (final clusterID in personClusters) { - final Iterable embedings = + final Iterable embeddings = await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID); - personEmbeddingsProto.addAll(embedings); + personEmbeddingsProto.addAll(embeddings); } final List sampledEmbeddingsProto = _randomSampleWithoutReplacement( personEmbeddingsProto, sampleSize, ); - final List> sampledEmbeddings = sampledEmbeddingsProto - .map((embedding) => EVector.fromBuffer(embedding).values) + final List sampledEmbeddings = sampledEmbeddingsProto + .map( + (embedding) => Vector.fromList( + EVector.fromBuffer(embedding).values, + dtype: DType.float32, + ), + ) .toList(growable: false); // Find the actual closest clusters for the person using median @@ -581,16 +583,20 @@ class ClusterFeedbackService { otherEmbeddingsProto, sampleSize, ); - final List> sampledOtherEmbeddings = - sampledOtherEmbeddingsProto - .map((embedding) => EVector.fromBuffer(embedding).values) - .toList(growable: false); + final List sampledOtherEmbeddings = sampledOtherEmbeddingsProto + .map( + (embedding) => Vector.fromList( + EVector.fromBuffer(embedding).values, + dtype: DType.float32, + ), + ) + .toList(growable: false); // Calculate distances and find the median final List distances = []; for (final otherEmbedding in sampledOtherEmbeddings) { for (final embedding in sampledEmbeddings) { - distances.add(cosineDistForNormVectors(embedding, otherEmbedding)); + distances.add(1 - embedding.dot(otherEmbedding)); } } distances.sort(); @@ -671,8 +677,9 @@ class ClusterFeedbackService { if (clusterToSummary[id]?.$2 == allClusterIdsToCountMap[id]) { allClusterIds.remove(id); clusterAvg[id] = Vector.fromList( - EVector.fromBuffer(clusterToSummary[id]!.$1).values, - dtype: DType.float32,); + EVector.fromBuffer(clusterToSummary[id]!.$1).values, + dtype: DType.float32, + ); alreadyUpdatedClustersCnt++; } if (allClusterIdsToCountMap[id]! < minClusterSize) { @@ -738,10 +745,12 @@ class ClusterFeedbackService { for (final clusterID in clusterEmbeddings.keys) { final Iterable embeddings = clusterEmbeddings[clusterID]!; - final Iterable vectors = embeddings.map((e) => Vector.fromList( - EVector.fromBuffer(e).values, - dtype: DType.float32, - ),); + final Iterable vectors = embeddings.map( + (e) => Vector.fromList( + EVector.fromBuffer(e).values, + dtype: DType.float32, + ), + ); final avg = vectors.reduce((a, b) => a + b) / vectors.length; final avgEmbeddingBuffer = EVector(values: avg).writeToBuffer(); updatesForClusterSummary[clusterID] = @@ -908,16 +917,16 @@ class ClusterFeedbackService { final personEmbeddingsCount = personClusters .map((e) => personClusterToSummary[e]!.$2) .reduce((a, b) => a + b); - final List personAvg = List.filled(192, 0); + Vector personAvg = Vector.filled(192, 0); for (final personClusterID in personClusters) { final personClusterBlob = personClusterToSummary[personClusterID]!.$1; - final personClusterAvg = EVector.fromBuffer(personClusterBlob).values; + final personClusterAvg = Vector.fromList( + EVector.fromBuffer(personClusterBlob).values, + dtype: DType.float32, + ); final clusterWeight = personClusterToSummary[personClusterID]!.$2 / personEmbeddingsCount; - for (int i = 0; i < personClusterAvg.length; i++) { - personAvg[i] += personClusterAvg[i] * - clusterWeight; // Weighted sum of the cluster averages - } + personAvg += personClusterAvg * clusterWeight; } w?.log('calculated person avg'); @@ -933,16 +942,22 @@ class ClusterFeedbackService { final faceIdToEmbeddingMap = await faceMlDb.getFaceEmbeddingMapForFaces( faceIDs, ); + final faceIdToVectorMap = faceIdToEmbeddingMap.map( + (key, value) => MapEntry( + key, + Vector.fromList( + EVector.fromBuffer(value).values, + dtype: DType.float32, + ), + ), + ); w?.log( 'got ${faceIdToEmbeddingMap.values.length} embeddings for ${suggestion.filesInCluster.length} files for cluster $clusterID', ); final fileIdToDistanceMap = {}; - for (final entry in faceIdToEmbeddingMap.entries) { + for (final entry in faceIdToVectorMap.entries) { fileIdToDistanceMap[getFileIdFromFaceId(entry.key)] = - cosineDistForNormVectors( - personAvg, - EVector.fromBuffer(entry.value).values, - ); + 1 - personAvg.dot(entry.value); } w?.log('calculated distances for cluster $clusterID'); suggestion.filesInCluster.sort((b, a) {