From bd495c386094d34bce1796b340ca98203fe43d59 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Mon, 6 May 2024 17:16:58 +0530 Subject: [PATCH] [mob][photos] Assert that embeddings are always normalized --- .../face_clustering/face_clustering_service.dart | 14 ++++++++++++-- .../face_ml/feedback/cluster_feedback.dart | 5 +++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart index 27df434bce..0e19ab4597 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart @@ -482,6 +482,14 @@ class FaceClusteringService { ); } + // Assert that the embeddings are normalized + for (final faceInfo in faceInfos) { + if (faceInfo.vEmbedding != null) { + final norm = faceInfo.vEmbedding!.norm(); + assert((norm - 1.0).abs() < 1e-5); + } + } + // Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first if (fileIDToCreationTime != null) { faceInfos.sort((a, b) { @@ -670,8 +678,10 @@ class FaceClusteringService { } else { final newMeanVector = newEmbeddings.reduce((a, b) => a + b); final newMeanVectorNormalized = newMeanVector / newMeanVector.norm(); - newClusterSummaries[clusterId] = - (EVector(values: newMeanVectorNormalized.toList()).writeToBuffer(), newCount); + newClusterSummaries[clusterId] = ( + EVector(values: newMeanVectorNormalized.toList()).writeToBuffer(), + newCount + ); } } log( diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index d8faa8c0ee..e9d244fba0 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -808,6 +808,11 @@ class ClusterFeedbackService { final alreadyUpdatedClustersCnt = serializationEmbeddings.$4; final smallerClustersCnt = serializationEmbeddings.$5; + // Assert that all existing clusterAvg are normalized + for (final avg in clusterAvg.values) { + assert((avg.norm() - 1.0).abs() < 1e-5); + } + w?.log( 'serialization of embeddings', );