[mob] merge mobile_face to fix_face_thumbnail

This commit is contained in:
ashilkn 2024-04-26 11:32:33 +05:30
commit a577611e65
14 changed files with 534 additions and 226 deletions

View File

@ -98,7 +98,7 @@ class FaceMLDataDB {
} }
} }
Future<void> updateClusterIdToFaceId( Future<void> updateFaceIdToClusterId(
Map<String, int> faceIDToClusterID, Map<String, int> faceIDToClusterID,
) async { ) async {
final db = await instance.database; final db = await instance.database;
@ -146,8 +146,8 @@ class FaceMLDataDB {
} }
Future<Map<int, int>> clusterIdToFaceCount() async { Future<Map<int, int>> clusterIdToFaceCount() async {
final db = await instance.database; final db = await instance.sqliteAsyncDB;
final List<Map<String, dynamic>> maps = await db.rawQuery( final List<Map<String, dynamic>> maps = await db.getAll(
'SELECT $fcClusterID, COUNT(*) as count FROM $faceClustersTable where $fcClusterID IS NOT NULL GROUP BY $fcClusterID ', 'SELECT $fcClusterID, COUNT(*) as count FROM $faceClustersTable where $fcClusterID IS NOT NULL GROUP BY $fcClusterID ',
); );
final Map<int, int> result = {}; final Map<int, int> result = {};
@ -158,15 +158,15 @@ class FaceMLDataDB {
} }
Future<Set<int>> getPersonIgnoredClusters(String personID) async { Future<Set<int>> getPersonIgnoredClusters(String personID) async {
final db = await instance.database; final db = await instance.sqliteAsyncDB;
// find out clusterIds that are assigned to other persons using the clusters table // find out clusterIds that are assigned to other persons using the clusters table
final List<Map<String, dynamic>> maps = await db.rawQuery( final List<Map<String, dynamic>> maps = await db.getAll(
'SELECT $clusterIDColumn FROM $clusterPersonTable WHERE $personIdColumn != ? AND $personIdColumn IS NOT NULL', 'SELECT $clusterIDColumn FROM $clusterPersonTable WHERE $personIdColumn != ? AND $personIdColumn IS NOT NULL',
[personID], [personID],
); );
final Set<int> ignoredClusterIDs = final Set<int> ignoredClusterIDs =
maps.map((e) => e[clusterIDColumn] as int).toSet(); maps.map((e) => e[clusterIDColumn] as int).toSet();
final List<Map<String, dynamic>> rejectMaps = await db.rawQuery( final List<Map<String, dynamic>> rejectMaps = await db.getAll(
'SELECT $clusterIDColumn FROM $notPersonFeedback WHERE $personIdColumn = ?', 'SELECT $clusterIDColumn FROM $notPersonFeedback WHERE $personIdColumn = ?',
[personID], [personID],
); );
@ -176,8 +176,8 @@ class FaceMLDataDB {
} }
Future<Set<int>> getPersonClusterIDs(String personID) async { Future<Set<int>> getPersonClusterIDs(String personID) async {
final db = await instance.database; final db = await instance.sqliteAsyncDB;
final List<Map<String, dynamic>> maps = await db.rawQuery( final List<Map<String, dynamic>> maps = await db.getAll(
'SELECT $clusterIDColumn FROM $clusterPersonTable WHERE $personIdColumn = ?', 'SELECT $clusterIDColumn FROM $clusterPersonTable WHERE $personIdColumn = ?',
[personID], [personID],
); );
@ -197,8 +197,8 @@ class FaceMLDataDB {
int clusterID, { int clusterID, {
int? limit, int? limit,
}) async { }) async {
final db = await instance.database; final db = await instance.sqliteAsyncDB;
final List<Map<String, dynamic>> maps = await db.rawQuery( final List<Map<String, dynamic>> maps = await db.getAll(
'SELECT $faceEmbeddingBlob FROM $facesTable WHERE $faceIDColumn in (SELECT $fcFaceId from $faceClustersTable where $fcClusterID = ?) ${limit != null ? 'LIMIT $limit' : ''}', 'SELECT $faceEmbeddingBlob FROM $facesTable WHERE $faceIDColumn in (SELECT $fcFaceId from $faceClustersTable where $fcClusterID = ?) ${limit != null ? 'LIMIT $limit' : ''}',
[clusterID], [clusterID],
); );
@ -209,7 +209,7 @@ class FaceMLDataDB {
Iterable<int> clusterIDs, { Iterable<int> clusterIDs, {
int? limit, int? limit,
}) async { }) async {
final db = await instance.database; final db = await instance.sqliteAsyncDB;
final Map<int, List<Uint8List>> result = {}; final Map<int, List<Uint8List>> result = {};
final selectQuery = ''' final selectQuery = '''
@ -220,7 +220,7 @@ class FaceMLDataDB {
${limit != null ? 'LIMIT $limit' : ''} ${limit != null ? 'LIMIT $limit' : ''}
'''; ''';
final List<Map<String, dynamic>> maps = await db.rawQuery(selectQuery); final List<Map<String, dynamic>> maps = await db.getAll(selectQuery);
for (final map in maps) { for (final map in maps) {
final clusterID = map[fcClusterID] as int; final clusterID = map[fcClusterID] as int;
@ -321,8 +321,8 @@ class FaceMLDataDB {
} }
Future<Face?> getFaceForFaceID(String faceID) async { Future<Face?> getFaceForFaceID(String faceID) async {
final db = await instance.database; final db = await instance.sqliteAsyncDB;
final result = await db.rawQuery( final result = await db.getAll(
'SELECT * FROM $facesTable where $faceIDColumn = ?', 'SELECT * FROM $facesTable where $faceIDColumn = ?',
[faceID], [faceID],
); );
@ -332,6 +332,36 @@ class FaceMLDataDB {
return mapRowToFace(result.first); return mapRowToFace(result.first);
} }
Future<Map<int, Iterable<String>>> getClusterToFaceIDs(
Set<int> clusterIDs,
) async {
final db = await instance.sqliteAsyncDB;
final Map<int, List<String>> result = {};
final List<Map<String, dynamic>> maps = await db.getAll(
'SELECT $fcClusterID, $fcFaceId FROM $faceClustersTable WHERE $fcClusterID IN (${clusterIDs.join(",")})',
);
for (final map in maps) {
final clusterID = map[fcClusterID] as int;
final faceID = map[fcFaceId] as String;
result.putIfAbsent(clusterID, () => <String>[]).add(faceID);
}
return result;
}
Future<Map<int, Iterable<String>>> getAllClusterIdToFaceIDs() async {
final db = await instance.sqliteAsyncDB;
final Map<int, List<String>> result = {};
final List<Map<String, dynamic>> maps = await db.getAll(
'SELECT $fcClusterID, $fcFaceId FROM $faceClustersTable',
);
for (final map in maps) {
final clusterID = map[fcClusterID] as int;
final faceID = map[fcFaceId] as String;
result.putIfAbsent(clusterID, () => <String>[]).add(faceID);
}
return result;
}
Future<Iterable<String>> getFaceIDsForCluster(int clusterID) async { Future<Iterable<String>> getFaceIDsForCluster(int clusterID) async {
final db = await instance.sqliteAsyncDB; final db = await instance.sqliteAsyncDB;
final List<Map<String, dynamic>> maps = await db.getAll( final List<Map<String, dynamic>> maps = await db.getAll(
@ -390,8 +420,8 @@ class FaceMLDataDB {
Future<Map<String, int?>> getFaceIdsToClusterIds( Future<Map<String, int?>> getFaceIdsToClusterIds(
Iterable<String> faceIds, Iterable<String> faceIds,
) async { ) async {
final db = await instance.database; final db = await instance.sqliteAsyncDB;
final List<Map<String, dynamic>> maps = await db.rawQuery( final List<Map<String, dynamic>> maps = await db.getAll(
'SELECT $fcFaceId, $fcClusterID FROM $faceClustersTable where $fcFaceId IN (${faceIds.map((id) => "'$id'").join(",")})', 'SELECT $fcFaceId, $fcClusterID FROM $faceClustersTable where $fcFaceId IN (${faceIds.map((id) => "'$id'").join(",")})',
); );
final Map<String, int?> result = {}; final Map<String, int?> result = {};
@ -403,8 +433,8 @@ class FaceMLDataDB {
Future<Map<int, Set<int>>> getFileIdToClusterIds() async { Future<Map<int, Set<int>>> getFileIdToClusterIds() async {
final Map<int, Set<int>> result = {}; final Map<int, Set<int>> result = {};
final db = await instance.database; final db = await instance.sqliteAsyncDB;
final List<Map<String, dynamic>> maps = await db.rawQuery( final List<Map<String, dynamic>> maps = await db.getAll(
'SELECT $fcClusterID, $fcFaceId FROM $faceClustersTable', 'SELECT $fcClusterID, $fcFaceId FROM $faceClustersTable',
); );
@ -761,9 +791,9 @@ class FaceMLDataDB {
// for a given personID, return a map of clusterID to fileIDs using join query // for a given personID, return a map of clusterID to fileIDs using join query
Future<Map<int, Set<int>>> getFileIdToClusterIDSet(String personID) { Future<Map<int, Set<int>>> getFileIdToClusterIDSet(String personID) {
final db = instance.database; final db = instance.sqliteAsyncDB;
return db.then((db) async { return db.then((db) async {
final List<Map<String, dynamic>> maps = await db.rawQuery( final List<Map<String, dynamic>> maps = await db.getAll(
'SELECT $faceClustersTable.$fcClusterID, $fcFaceId FROM $faceClustersTable ' 'SELECT $faceClustersTable.$fcClusterID, $fcFaceId FROM $faceClustersTable '
'INNER JOIN $clusterPersonTable ' 'INNER JOIN $clusterPersonTable '
'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$clusterIDColumn ' 'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$clusterIDColumn '
@ -784,9 +814,9 @@ class FaceMLDataDB {
Future<Map<int, Set<int>>> getFileIdToClusterIDSetForCluster( Future<Map<int, Set<int>>> getFileIdToClusterIDSetForCluster(
Set<int> clusterIDs, Set<int> clusterIDs,
) { ) {
final db = instance.database; final db = instance.sqliteAsyncDB;
return db.then((db) async { return db.then((db) async {
final List<Map<String, dynamic>> maps = await db.rawQuery( final List<Map<String, dynamic>> maps = await db.getAll(
'SELECT $fcClusterID, $fcFaceId FROM $faceClustersTable ' 'SELECT $fcClusterID, $fcFaceId FROM $faceClustersTable '
'WHERE $fcClusterID IN (${clusterIDs.join(",")})', 'WHERE $fcClusterID IN (${clusterIDs.join(",")})',
); );
@ -846,9 +876,26 @@ class FaceMLDataDB {
return result; return result;
} }
Future<Map<int, (Uint8List, int)>> getClusterToClusterSummary(
Iterable<int> clusterIDs,
) async {
final db = await instance.sqliteAsyncDB;
final Map<int, (Uint8List, int)> result = {};
final rows = await db.getAll(
'SELECT * FROM $clusterSummaryTable WHERE $clusterIDColumn IN (${clusterIDs.join(",")})',
);
for (final r in rows) {
final id = r[clusterIDColumn] as int;
final avg = r[avgColumn] as Uint8List;
final count = r[countColumn] as int;
result[id] = (avg, count);
}
return result;
}
Future<Map<int, String>> getClusterIDToPersonID() async { Future<Map<int, String>> getClusterIDToPersonID() async {
final db = await instance.database; final db = await instance.sqliteAsyncDB;
final List<Map<String, dynamic>> maps = await db.rawQuery( final List<Map<String, dynamic>> maps = await db.getAll(
'SELECT $personIdColumn, $clusterIDColumn FROM $clusterPersonTable', 'SELECT $personIdColumn, $clusterIDColumn FROM $clusterPersonTable',
); );
final Map<int, String> result = {}; final Map<int, String> result = {};

View File

@ -61,7 +61,7 @@ class EntityService {
}) async { }) async {
final key = await getOrCreateEntityKey(type); final key = await getOrCreateEntityKey(type);
final encryptedKeyData = await CryptoUtil.encryptChaCha( final encryptedKeyData = await CryptoUtil.encryptChaCha(
utf8.encode(plainText) as Uint8List, utf8.encode(plainText),
key, key,
); );
final String encryptedData = final String encryptedData =

View File

@ -1,5 +1,18 @@
import 'dart:math' show sqrt; import 'dart:math' show sqrt;
import "package:ml_linalg/vector.dart";
/// Calculates the cosine distance between two embeddings/vectors using SIMD from ml_linalg
///
/// WARNING: This assumes both vectors are already normalized!
double cosineDistanceSIMD(Vector vector1, Vector vector2) {
if (vector1.length != vector2.length) {
throw ArgumentError('Vectors must be the same length');
}
return 1 - vector1.dot(vector2);
}
/// Calculates the cosine distance between two embeddings/vectors. /// Calculates the cosine distance between two embeddings/vectors.
/// ///
/// Throws an ArgumentError if the vectors are of different lengths or /// Throws an ArgumentError if the vectors are of different lengths or

View File

@ -69,7 +69,7 @@ class FaceClusteringService {
bool isRunning = false; bool isRunning = false;
static const kRecommendedDistanceThreshold = 0.24; static const kRecommendedDistanceThreshold = 0.24;
static const kConservativeDistanceThreshold = 0.06; static const kConservativeDistanceThreshold = 0.16;
// singleton pattern // singleton pattern
FaceClusteringService._privateConstructor(); FaceClusteringService._privateConstructor();
@ -560,10 +560,10 @@ class FaceClusteringService {
for (int j = i - 1; j >= 0; j--) { for (int j = i - 1; j >= 0; j--) {
late double distance; late double distance;
if (sortedFaceInfos[i].vEmbedding != null) { if (sortedFaceInfos[i].vEmbedding != null) {
distance = 1.0 - distance = cosineDistanceSIMD(
sortedFaceInfos[i] sortedFaceInfos[i].vEmbedding!,
.vEmbedding! sortedFaceInfos[j].vEmbedding!,
.dot(sortedFaceInfos[j].vEmbedding!); );
} else { } else {
distance = cosineDistForNormVectors( distance = cosineDistForNormVectors(
sortedFaceInfos[i].embedding!, sortedFaceInfos[i].embedding!,
@ -804,8 +804,10 @@ class FaceClusteringService {
double closestDistance = double.infinity; double closestDistance = double.infinity;
for (int j = 0; j < totalFaces; j++) { for (int j = 0; j < totalFaces; j++) {
if (i == j) continue; if (i == j) continue;
final double distance = final double distance = cosineDistanceSIMD(
1.0 - faceInfos[i].vEmbedding!.dot(faceInfos[j].vEmbedding!); faceInfos[i].vEmbedding!,
faceInfos[j].vEmbedding!,
);
if (distance < closestDistance) { if (distance < closestDistance) {
closestDistance = distance; closestDistance = distance;
closestIdx = j; closestIdx = j;
@ -855,10 +857,10 @@ class FaceClusteringService {
for (int i = 0; i < clusterIds.length; i++) { for (int i = 0; i < clusterIds.length; i++) {
for (int j = 0; j < clusterIds.length; j++) { for (int j = 0; j < clusterIds.length; j++) {
if (i == j) continue; if (i == j) continue;
final double newDistance = 1.0 - final double newDistance = cosineDistanceSIMD(
clusterIdToMeanEmbeddingAndWeight[clusterIds[i]]!.$1.dot( clusterIdToMeanEmbeddingAndWeight[clusterIds[i]]!.$1,
clusterIdToMeanEmbeddingAndWeight[clusterIds[j]]!.$1, clusterIdToMeanEmbeddingAndWeight[clusterIds[j]]!.$1,
); );
if (newDistance < distance) { if (newDistance < distance) {
distance = newDistance; distance = newDistance;
clusterIDsToMerge = (clusterIds[i], clusterIds[j]); clusterIDsToMerge = (clusterIds[i], clusterIds[j]);
@ -959,9 +961,9 @@ class FaceClusteringService {
// Run the DBSCAN clustering // Run the DBSCAN clustering
final List<List<int>> clusterOutput = dbscan.run(embeddings); final List<List<int>> clusterOutput = dbscan.run(embeddings);
final List<List<FaceInfo>> clusteredFaceInfos = clusterOutput // final List<List<FaceInfo>> clusteredFaceInfos = clusterOutput
.map((cluster) => cluster.map((idx) => faceInfos[idx]).toList()) // .map((cluster) => cluster.map((idx) => faceInfos[idx]).toList())
.toList(); // .toList();
final List<List<String>> clusteredFaceIDs = clusterOutput final List<List<String>> clusteredFaceIDs = clusterOutput
.map((cluster) => cluster.map((idx) => faceInfos[idx].faceID).toList()) .map((cluster) => cluster.map((idx) => faceInfos[idx].faceID).toList())
.toList(); .toList();

View File

@ -1,8 +1,8 @@
import 'package:photos/services/machine_learning/face_ml/face_detection/face_detection_service.dart'; import 'package:photos/services/machine_learning/face_ml/face_detection/face_detection_service.dart';
/// Blur detection threshold /// Blur detection threshold
const kLaplacianHardThreshold = 15; const kLaplacianHardThreshold = 10;
const kLaplacianSoftThreshold = 100; const kLaplacianSoftThreshold = 50;
const kLaplacianVerySoftThreshold = 200; const kLaplacianVerySoftThreshold = 200;
/// Default blur value /// Default blur value

View File

@ -350,7 +350,7 @@ class FaceMlService {
} }
await FaceMLDataDB.instance await FaceMLDataDB.instance
.updateClusterIdToFaceId(clusteringResult.newFaceIdToCluster); .updateFaceIdToClusterId(clusteringResult.newFaceIdToCluster);
await FaceMLDataDB.instance await FaceMLDataDB.instance
.clusterSummaryUpdate(clusteringResult.newClusterSummaries!); .clusterSummaryUpdate(clusteringResult.newClusterSummaries!);
_logger.info( _logger.info(
@ -403,7 +403,7 @@ class FaceMlService {
'Updating ${clusteringResult.newFaceIdToCluster.length} FaceIDs with clusterIDs in the DB', 'Updating ${clusteringResult.newFaceIdToCluster.length} FaceIDs with clusterIDs in the DB',
); );
await FaceMLDataDB.instance await FaceMLDataDB.instance
.updateClusterIdToFaceId(clusteringResult.newFaceIdToCluster); .updateFaceIdToClusterId(clusteringResult.newFaceIdToCluster);
await FaceMLDataDB.instance await FaceMLDataDB.instance
.clusterSummaryUpdate(clusteringResult.newClusterSummaries!); .clusterSummaryUpdate(clusteringResult.newClusterSummaries!);
_logger.info('Done updating FaceIDs with clusterIDs in the DB, in ' _logger.info('Done updating FaceIDs with clusterIDs in the DB, in '

View File

@ -1,19 +1,18 @@
import 'dart:developer' as dev; import 'dart:developer' as dev;
import "dart:math" show Random; import "dart:math" show Random, min;
import "package:flutter/foundation.dart"; import "package:flutter/foundation.dart";
import "package:logging/logging.dart"; import "package:logging/logging.dart";
import "package:ml_linalg/linalg.dart";
import "package:photos/core/event_bus.dart"; import "package:photos/core/event_bus.dart";
import "package:photos/db/files_db.dart"; import "package:photos/db/files_db.dart";
// import "package:photos/events/files_updated_event.dart";
// import "package:photos/events/local_photos_updated_event.dart";
import "package:photos/events/people_changed_event.dart"; import "package:photos/events/people_changed_event.dart";
import "package:photos/extensions/stop_watch.dart"; import "package:photos/extensions/stop_watch.dart";
import "package:photos/face/db.dart"; import "package:photos/face/db.dart";
import "package:photos/face/model/person.dart"; import "package:photos/face/model/person.dart";
import "package:photos/generated/protos/ente/common/vector.pb.dart"; import "package:photos/generated/protos/ente/common/vector.pb.dart";
import "package:photos/models/file/file.dart"; import "package:photos/models/file/file.dart";
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart'; import "package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart";
import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart"; import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart"; import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart"; import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
@ -25,12 +24,14 @@ class ClusterSuggestion {
final double distancePersonToCluster; final double distancePersonToCluster;
final bool usedOnlyMeanForSuggestion; final bool usedOnlyMeanForSuggestion;
final List<EnteFile> filesInCluster; final List<EnteFile> filesInCluster;
final List<String> faceIDsInCluster;
ClusterSuggestion( ClusterSuggestion(
this.clusterIDToMerge, this.clusterIDToMerge,
this.distancePersonToCluster, this.distancePersonToCluster,
this.usedOnlyMeanForSuggestion, this.usedOnlyMeanForSuggestion,
this.filesInCluster, this.filesInCluster,
this.faceIDsInCluster,
); );
} }
@ -60,19 +61,27 @@ class ClusterFeedbackService {
bool extremeFilesFirst = true, bool extremeFilesFirst = true,
}) async { }) async {
_logger.info( _logger.info(
'getClusterFilesForPersonID ${kDebugMode ? person.data.name : person.remoteID}', 'getSuggestionForPerson ${kDebugMode ? person.data.name : person.remoteID}',
); );
try { try {
// Get the suggestions for the person using centroids and median // Get the suggestions for the person using centroids and median
final List<(int, double, bool)> suggestClusterIds = final startTime = DateTime.now();
final List<(int, double, bool)> foundSuggestions =
await _getSuggestions(person); await _getSuggestions(person);
final findSuggestionsTime = DateTime.now();
_logger.info(
'getSuggestionForPerson `_getSuggestions`: Found ${foundSuggestions.length} suggestions in ${findSuggestionsTime.difference(startTime).inMilliseconds} ms',
);
// Get the files for the suggestions // Get the files for the suggestions
final suggestionClusterIDs = foundSuggestions.map((e) => e.$1).toSet();
final Map<int, Set<int>> fileIdToClusterID = final Map<int, Set<int>> fileIdToClusterID =
await FaceMLDataDB.instance.getFileIdToClusterIDSetForCluster( await FaceMLDataDB.instance.getFileIdToClusterIDSetForCluster(
suggestClusterIds.map((e) => e.$1).toSet(), suggestionClusterIDs,
); );
final clusterIdToFaceIDs =
await FaceMLDataDB.instance.getClusterToFaceIDs(suggestionClusterIDs);
final Map<int, List<EnteFile>> clusterIDToFiles = {}; final Map<int, List<EnteFile>> clusterIDToFiles = {};
final allFiles = await SearchService.instance.getAllFiles(); final allFiles = await SearchService.instance.getAllFiles();
for (final f in allFiles) { for (final f in allFiles) {
@ -89,25 +98,31 @@ class ClusterFeedbackService {
} }
} }
final List<ClusterSuggestion> clusterIdAndFiles = []; final List<ClusterSuggestion> finalSuggestions = [];
for (final clusterSuggestion in suggestClusterIds) { for (final clusterSuggestion in foundSuggestions) {
if (clusterIDToFiles.containsKey(clusterSuggestion.$1)) { if (clusterIDToFiles.containsKey(clusterSuggestion.$1)) {
clusterIdAndFiles.add( finalSuggestions.add(
ClusterSuggestion( ClusterSuggestion(
clusterSuggestion.$1, clusterSuggestion.$1,
clusterSuggestion.$2, clusterSuggestion.$2,
clusterSuggestion.$3, clusterSuggestion.$3,
clusterIDToFiles[clusterSuggestion.$1]!, clusterIDToFiles[clusterSuggestion.$1]!,
clusterIdToFaceIDs[clusterSuggestion.$1]!.toList(),
), ),
); );
} }
} }
final getFilesTime = DateTime.now();
final sortingStartTime = DateTime.now();
if (extremeFilesFirst) { if (extremeFilesFirst) {
await _sortSuggestionsOnDistanceToPerson(person, clusterIdAndFiles); await _sortSuggestionsOnDistanceToPerson(person, finalSuggestions);
} }
_logger.info(
'getSuggestionForPerson post-processing suggestions took ${DateTime.now().difference(findSuggestionsTime).inMilliseconds} ms, of which sorting took ${DateTime.now().difference(sortingStartTime).inMilliseconds} ms and getting files took ${getFilesTime.difference(findSuggestionsTime).inMilliseconds} ms',
);
return clusterIdAndFiles; return finalSuggestions;
} catch (e, s) { } catch (e, s) {
_logger.severe("Error in getClusterFilesForPersonID", e, s); _logger.severe("Error in getClusterFilesForPersonID", e, s);
rethrow; rethrow;
@ -229,13 +244,13 @@ class ClusterFeedbackService {
final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID); final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID); final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
dev.log( dev.log(
'existing clusters for ${p.data.name} are $personClusters', '${p.data.name} has ${personClusters.length} existing clusters',
name: "ClusterFeedbackService", name: "ClusterFeedbackService",
); );
// Get and update the cluster summary to get the avg (centroid) and count // Get and update the cluster summary to get the avg (centroid) and count
final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start(); final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start();
final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg( final Map<int, Vector> clusterAvg = await _getUpdateClusterAvg(
allClusterIdsToCountMap, allClusterIdsToCountMap,
ignoredClusters, ignoredClusters,
); );
@ -397,7 +412,7 @@ class ClusterFeedbackService {
final newClusterID = startClusterID + blurValue ~/ 10; final newClusterID = startClusterID + blurValue ~/ 10;
faceIdToCluster[faceID] = newClusterID; faceIdToCluster[faceID] = newClusterID;
} }
await FaceMLDataDB.instance.updateClusterIdToFaceId(faceIdToCluster); await FaceMLDataDB.instance.updateFaceIdToClusterId(faceIdToCluster);
Bus.instance.fire(PeopleChangedEvent()); Bus.instance.fire(PeopleChangedEvent());
} catch (e, s) { } catch (e, s) {
@ -437,69 +452,81 @@ class ClusterFeedbackService {
Future<List<(int, double, bool)>> _getSuggestions( Future<List<(int, double, bool)>> _getSuggestions(
PersonEntity p, { PersonEntity p, {
int sampleSize = 50, int sampleSize = 50,
double maxMedianDistance = 0.65, double maxMedianDistance = 0.62,
double goodMedianDistance = 0.55, double goodMedianDistance = 0.55,
double maxMeanDistance = 0.65, double maxMeanDistance = 0.65,
double goodMeanDistance = 0.5, double goodMeanDistance = 0.50,
}) async { }) async {
final w = (kDebugMode ? EnteWatch('getSuggestions') : null)?..start();
// Get all the cluster data // Get all the cluster data
final startTime = DateTime.now();
final faceMlDb = FaceMLDataDB.instance; final faceMlDb = FaceMLDataDB.instance;
// final Map<int, List<(int, double)>> suggestions = {};
final allClusterIdsToCountMap = await faceMlDb.clusterIdToFaceCount(); final allClusterIdsToCountMap = await faceMlDb.clusterIdToFaceCount();
final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID); final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID); final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
dev.log( final personFaceIDs =
'existing clusters for ${p.data.name} are $personClusters, getting all database data took ${DateTime.now().difference(startTime).inMilliseconds} ms', await FaceMLDataDB.instance.getFaceIDsForPerson(p.remoteID);
name: "getSuggestionsUsingMedian", final personFileIDs = personFaceIDs.map(getFileIdFromFaceId).toSet();
w?.log(
'${p.data.name} has ${personClusters.length} existing clusters, getting all database data done',
); );
final allClusterIdToFaceIDs =
await FaceMLDataDB.instance.getAllClusterIdToFaceIDs();
w?.log('getAllClusterIdToFaceIDs done');
// First only do a simple check on the big clusters // First only do a simple check on the big clusters, if the person does not have small clusters yet
final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start(); final smallestPersonClusterSize = personClusters
final Map<int, List<double>> clusterAvgBigClusters = .map((clusterID) => allClusterIdsToCountMap[clusterID] ?? 0)
await _getUpdateClusterAvg( .reduce((value, element) => min(value, element));
allClusterIdsToCountMap, final checkSizes = [20, kMinimumClusterSizeSearchResult, 10, 5, 1];
ignoredClusters, late Map<int, Vector> clusterAvgBigClusters;
minClusterSize: kMinimumClusterSizeSearchResult, final List<(int, double)> suggestionsMean = [];
); for (final minimumSize in checkSizes.toSet()) {
dev.log( // if (smallestPersonClusterSize >= minimumSize) {
'computed avg for ${clusterAvgBigClusters.length} clusters, in ${DateTime.now().difference(startTime).inMilliseconds} ms', clusterAvgBigClusters = await _getUpdateClusterAvg(
); allClusterIdsToCountMap,
final List<(int, double)> suggestionsMeanBigClusters = _calcSuggestionsMean( ignoredClusters,
clusterAvgBigClusters, minClusterSize: minimumSize,
personClusters, );
ignoredClusters, w?.log(
goodMeanDistance, 'Calculate avg for ${clusterAvgBigClusters.length} clusters of min size $minimumSize',
); );
if (suggestionsMeanBigClusters.isNotEmpty) { final List<(int, double)> suggestionsMeanBigClusters =
return suggestionsMeanBigClusters _calcSuggestionsMean(
.map((e) => (e.$1, e.$2, true)) clusterAvgBigClusters,
.toList(growable: false); personClusters,
} ignoredClusters,
goodMeanDistance,
// Get and update the cluster summary to get the avg (centroid) and count );
final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg( w?.log(
allClusterIdsToCountMap, 'Calculate suggestions using mean for ${clusterAvgBigClusters.length} clusters of min size $minimumSize',
ignoredClusters, );
); for (final suggestion in suggestionsMeanBigClusters) {
dev.log( // Skip suggestions that have a high overlap with the person's files
'computed avg for ${clusterAvg.length} clusters, in ${DateTime.now().difference(startTime).inMilliseconds} ms', final suggestionSet = allClusterIdToFaceIDs[suggestion.$1]!
); .map((faceID) => getFileIdFromFaceId(faceID))
.toSet();
// Find the other cluster candidates based on the mean final overlap = personFileIDs.intersection(suggestionSet);
final List<(int, double)> suggestionsMean = _calcSuggestionsMean( if (overlap.isNotEmpty &&
clusterAvg, ((overlap.length / suggestionSet.length) > 0.5)) {
personClusters, await FaceMLDataDB.instance.captureNotPersonFeedback(
ignoredClusters, personID: p.remoteID,
goodMeanDistance, clusterID: suggestion.$1,
); );
if (suggestionsMean.isNotEmpty) { continue;
return suggestionsMean }
.map((e) => (e.$1, e.$2, true)) suggestionsMean.add(suggestion);
.toList(growable: false); }
if (suggestionsMean.isNotEmpty) {
return suggestionsMean
.map((e) => (e.$1, e.$2, true))
.toList(growable: false);
// }
}
} }
w?.reset();
// Find the other cluster candidates based on the median // Find the other cluster candidates based on the median
final clusterAvg = clusterAvgBigClusters;
final List<(int, double)> moreSuggestionsMean = _calcSuggestionsMean( final List<(int, double)> moreSuggestionsMean = _calcSuggestionsMean(
clusterAvg, clusterAvg,
personClusters, personClusters,
@ -522,21 +549,26 @@ class ClusterFeedbackService {
"Found potential suggestions from loose mean for median test: $otherClusterIdsCandidates", "Found potential suggestions from loose mean for median test: $otherClusterIdsCandidates",
); );
watch.logAndReset("Starting median test"); w?.logAndReset("Starting median test");
// Take the embeddings from the person's clusters in one big list and sample from it // Take the embeddings from the person's clusters in one big list and sample from it
final List<Uint8List> personEmbeddingsProto = []; final List<Uint8List> personEmbeddingsProto = [];
for (final clusterID in personClusters) { for (final clusterID in personClusters) {
final Iterable<Uint8List> embedings = final Iterable<Uint8List> embeddings =
await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID); await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID);
personEmbeddingsProto.addAll(embedings); personEmbeddingsProto.addAll(embeddings);
} }
final List<Uint8List> sampledEmbeddingsProto = final List<Uint8List> sampledEmbeddingsProto =
_randomSampleWithoutReplacement( _randomSampleWithoutReplacement(
personEmbeddingsProto, personEmbeddingsProto,
sampleSize, sampleSize,
); );
final List<List<double>> sampledEmbeddings = sampledEmbeddingsProto final List<Vector> sampledEmbeddings = sampledEmbeddingsProto
.map((embedding) => EVector.fromBuffer(embedding).values) .map(
(embedding) => Vector.fromList(
EVector.fromBuffer(embedding).values,
dtype: DType.float32,
),
)
.toList(growable: false); .toList(growable: false);
// Find the actual closest clusters for the person using median // Find the actual closest clusters for the person using median
@ -552,16 +584,20 @@ class ClusterFeedbackService {
otherEmbeddingsProto, otherEmbeddingsProto,
sampleSize, sampleSize,
); );
final List<List<double>> sampledOtherEmbeddings = final List<Vector> sampledOtherEmbeddings = sampledOtherEmbeddingsProto
sampledOtherEmbeddingsProto .map(
.map((embedding) => EVector.fromBuffer(embedding).values) (embedding) => Vector.fromList(
.toList(growable: false); EVector.fromBuffer(embedding).values,
dtype: DType.float32,
),
)
.toList(growable: false);
// Calculate distances and find the median // Calculate distances and find the median
final List<double> distances = []; final List<double> distances = [];
for (final otherEmbedding in sampledOtherEmbeddings) { for (final otherEmbedding in sampledOtherEmbeddings) {
for (final embedding in sampledEmbeddings) { for (final embedding in sampledEmbeddings) {
distances.add(cosineDistForNormVectors(embedding, otherEmbedding)); distances.add(cosineDistanceSIMD(embedding, otherEmbedding));
} }
} }
distances.sort(); distances.sort();
@ -575,7 +611,7 @@ class ClusterFeedbackService {
} }
} }
} }
watch.log("Finished median test"); w?.log("Finished median test");
if (suggestionsMedian.isEmpty) { if (suggestionsMedian.isEmpty) {
_logger.info("No suggestions found using median"); _logger.info("No suggestions found using median");
return []; return [];
@ -607,13 +643,14 @@ class ClusterFeedbackService {
return finalSuggestionsMedian; return finalSuggestionsMedian;
} }
Future<Map<int, List<double>>> _getUpdateClusterAvg( Future<Map<int, Vector>> _getUpdateClusterAvg(
Map<int, int> allClusterIdsToCountMap, Map<int, int> allClusterIdsToCountMap,
Set<int> ignoredClusters, { Set<int> ignoredClusters, {
int minClusterSize = 1, int minClusterSize = 1,
int maxClusterInCurrentRun = 500, int maxClusterInCurrentRun = 500,
int maxEmbeddingToRead = 10000, int maxEmbeddingToRead = 10000,
}) async { }) async {
final w = (kDebugMode ? EnteWatch('_getUpdateClusterAvg') : null)?..start();
final startTime = DateTime.now(); final startTime = DateTime.now();
final faceMlDb = FaceMLDataDB.instance; final faceMlDb = FaceMLDataDB.instance;
_logger.info( _logger.info(
@ -624,16 +661,15 @@ class ClusterFeedbackService {
await faceMlDb.getAllClusterSummary(minClusterSize); await faceMlDb.getAllClusterSummary(minClusterSize);
final Map<int, (Uint8List, int)> updatesForClusterSummary = {}; final Map<int, (Uint8List, int)> updatesForClusterSummary = {};
final Map<int, List<double>> clusterAvg = {}; final Map<int, Vector> clusterAvg = {};
dev.log( w?.log(
'getUpdateClusterAvg database call for getAllClusterSummary took ${DateTime.now().difference(startTime).inMilliseconds} ms', 'getUpdateClusterAvg database call for getAllClusterSummary',
); );
final allClusterIds = allClusterIdsToCountMap.keys.toSet(); final allClusterIds = allClusterIdsToCountMap.keys.toSet();
int ignoredClustersCnt = 0, alreadyUpdatedClustersCnt = 0; int ignoredClustersCnt = 0, alreadyUpdatedClustersCnt = 0;
int smallerClustersCnt = 0; int smallerClustersCnt = 0;
final serializationTime = DateTime.now();
for (final id in allClusterIdsToCountMap.keys) { for (final id in allClusterIdsToCountMap.keys) {
if (ignoredClusters.contains(id)) { if (ignoredClusters.contains(id)) {
allClusterIds.remove(id); allClusterIds.remove(id);
@ -641,7 +677,10 @@ class ClusterFeedbackService {
} }
if (clusterToSummary[id]?.$2 == allClusterIdsToCountMap[id]) { if (clusterToSummary[id]?.$2 == allClusterIdsToCountMap[id]) {
allClusterIds.remove(id); allClusterIds.remove(id);
clusterAvg[id] = EVector.fromBuffer(clusterToSummary[id]!.$1).values; clusterAvg[id] = Vector.fromList(
EVector.fromBuffer(clusterToSummary[id]!.$1).values,
dtype: DType.float32,
);
alreadyUpdatedClustersCnt++; alreadyUpdatedClustersCnt++;
} }
if (allClusterIdsToCountMap[id]! < minClusterSize) { if (allClusterIdsToCountMap[id]! < minClusterSize) {
@ -649,8 +688,8 @@ class ClusterFeedbackService {
smallerClustersCnt++; smallerClustersCnt++;
} }
} }
dev.log( w?.log(
'serialization of embeddings took ${DateTime.now().difference(serializationTime).inMilliseconds} ms', 'serialization of embeddings',
); );
_logger.info( _logger.info(
'Ignored $ignoredClustersCnt clusters, already updated $alreadyUpdatedClustersCnt clusters, $smallerClustersCnt clusters are smaller than $minClusterSize', 'Ignored $ignoredClustersCnt clusters, already updated $alreadyUpdatedClustersCnt clusters, $smallerClustersCnt clusters are smaller than $minClusterSize',
@ -670,12 +709,7 @@ class ClusterFeedbackService {
allClusterIdsToCountMap[b]!.compareTo(allClusterIdsToCountMap[a]!), allClusterIdsToCountMap[b]!.compareTo(allClusterIdsToCountMap[a]!),
); );
int indexedInCurrentRun = 0; int indexedInCurrentRun = 0;
final EnteWatch? w = kDebugMode ? EnteWatch("computeAvg") : null; w?.reset();
w?.start();
w?.log(
'reading embeddings for $maxClusterInCurrentRun or ${sortedClusterIDs.length} clusters',
);
int currentPendingRead = 0; int currentPendingRead = 0;
final List<int> clusterIdsToRead = []; final List<int> clusterIdsToRead = [];
@ -706,19 +740,17 @@ class ClusterFeedbackService {
); );
for (final clusterID in clusterEmbeddings.keys) { for (final clusterID in clusterEmbeddings.keys) {
late List<double> avg; final Iterable<Uint8List> embeddings = clusterEmbeddings[clusterID]!;
final Iterable<Uint8List> embedings = clusterEmbeddings[clusterID]!; final Iterable<Vector> vectors = embeddings.map(
final List<double> sum = List.filled(192, 0); (e) => Vector.fromList(
for (final embedding in embedings) { EVector.fromBuffer(e).values,
final data = EVector.fromBuffer(embedding).values; dtype: DType.float32,
for (int i = 0; i < sum.length; i++) { ),
sum[i] += data[i]; );
} final avg = vectors.reduce((a, b) => a + b) / vectors.length;
} final avgEmbeddingBuffer = EVector(values: avg).writeToBuffer();
avg = sum.map((e) => e / embedings.length).toList();
final avgEmbeedingBuffer = EVector(values: avg).writeToBuffer();
updatesForClusterSummary[clusterID] = updatesForClusterSummary[clusterID] =
(avgEmbeedingBuffer, embedings.length); (avgEmbeddingBuffer, embeddings.length);
// store the intermediate updates // store the intermediate updates
indexedInCurrentRun++; indexedInCurrentRun++;
if (updatesForClusterSummary.length > 100) { if (updatesForClusterSummary.length > 100) {
@ -745,20 +777,22 @@ class ClusterFeedbackService {
/// Returns a map of person's clusterID to map of closest clusterID to with disstance /// Returns a map of person's clusterID to map of closest clusterID to with disstance
List<(int, double)> _calcSuggestionsMean( List<(int, double)> _calcSuggestionsMean(
Map<int, List<double>> clusterAvg, Map<int, Vector> clusterAvg,
Set<int> personClusters, Set<int> personClusters,
Set<int> ignoredClusters, Set<int> ignoredClusters,
double maxClusterDistance, { double maxClusterDistance, {
Map<int, int>? allClusterIdsToCountMap, Map<int, int>? allClusterIdsToCountMap,
}) { }) {
final Map<int, List<(int, double)>> suggestions = {}; final Map<int, List<(int, double)>> suggestions = {};
int suggestionCount = 0;
final w = (kDebugMode ? EnteWatch('getSuggestions') : null)?..start();
for (final otherClusterID in clusterAvg.keys) { for (final otherClusterID in clusterAvg.keys) {
// ignore the cluster that belong to the person or is ignored // ignore the cluster that belong to the person or is ignored
if (personClusters.contains(otherClusterID) || if (personClusters.contains(otherClusterID) ||
ignoredClusters.contains(otherClusterID)) { ignoredClusters.contains(otherClusterID)) {
continue; continue;
} }
final otherAvg = clusterAvg[otherClusterID]!; final Vector otherAvg = clusterAvg[otherClusterID]!;
int? nearestPersonCluster; int? nearestPersonCluster;
double? minDistance; double? minDistance;
for (final personCluster in personClusters) { for (final personCluster in personClusters) {
@ -766,8 +800,8 @@ class ClusterFeedbackService {
_logger.info('no avg for cluster $personCluster'); _logger.info('no avg for cluster $personCluster');
continue; continue;
} }
final avg = clusterAvg[personCluster]!; final Vector avg = clusterAvg[personCluster]!;
final distance = cosineDistForNormVectors(avg, otherAvg); final distance = cosineDistanceSIMD(avg, otherAvg);
if (distance < maxClusterDistance) { if (distance < maxClusterDistance) {
if (minDistance == null || distance < minDistance) { if (minDistance == null || distance < minDistance) {
minDistance = distance; minDistance = distance;
@ -779,30 +813,35 @@ class ClusterFeedbackService {
suggestions suggestions
.putIfAbsent(nearestPersonCluster, () => []) .putIfAbsent(nearestPersonCluster, () => [])
.add((otherClusterID, minDistance)); .add((otherClusterID, minDistance));
suggestionCount++;
}
if (suggestionCount >= 2000) {
break;
} }
} }
w?.log('calculation inside calcSuggestionsMean');
if (suggestions.isNotEmpty) { if (suggestions.isNotEmpty) {
final List<(int, double)> suggestClusterIds = []; final List<(int, double)> suggestClusterIds = [];
for (final List<(int, double)> suggestion in suggestions.values) { for (final List<(int, double)> suggestion in suggestions.values) {
suggestClusterIds.addAll(suggestion); suggestClusterIds.addAll(suggestion);
} }
List<int>? suggestClusterIdsSizes; suggestClusterIds.sort(
if (allClusterIdsToCountMap != null) { (a, b) => a.$2.compareTo(b.$2),
suggestClusterIds.sort( ); // sort by distance
(a, b) => allClusterIdsToCountMap[b.$1]!
.compareTo(allClusterIdsToCountMap[a.$1]!), // List<int>? suggestClusterIdsSizes;
); // if (allClusterIdsToCountMap != null) {
suggestClusterIdsSizes = suggestClusterIds // suggestClusterIdsSizes = suggestClusterIds
.map((e) => allClusterIdsToCountMap[e.$1]!) // .map((e) => allClusterIdsToCountMap[e.$1]!)
.toList(growable: false); // .toList(growable: false);
} // }
final suggestClusterIdsDistances = // final suggestClusterIdsDistances =
suggestClusterIds.map((e) => e.$2).toList(growable: false); // suggestClusterIds.map((e) => e.$2).toList(growable: false);
_logger.info( _logger.info(
"Already found good suggestions using mean: $suggestClusterIds, ${suggestClusterIdsSizes != null ? 'with sizes $suggestClusterIdsSizes' : ''} and distances $suggestClusterIdsDistances", "Already found ${suggestClusterIds.length} good suggestions using mean",
); );
return suggestClusterIds; return suggestClusterIds.sublist(0, min(suggestClusterIds.length, 20));
} else { } else {
_logger.info("No suggestions found using mean"); _logger.info("No suggestions found using mean");
return <(int, double)>[]; return <(int, double)>[];
@ -841,56 +880,88 @@ class ClusterFeedbackService {
Future<void> _sortSuggestionsOnDistanceToPerson( Future<void> _sortSuggestionsOnDistanceToPerson(
PersonEntity person, PersonEntity person,
List<ClusterSuggestion> suggestions, List<ClusterSuggestion> suggestions, {
) async { bool onlySortBigSuggestions = true,
}) async {
if (suggestions.isEmpty) { if (suggestions.isEmpty) {
debugPrint('No suggestions to sort'); debugPrint('No suggestions to sort');
return; return;
} }
if (onlySortBigSuggestions) {
final bigSuggestions = suggestions
.where(
(s) => s.filesInCluster.length > kMinimumClusterSizeSearchResult,
)
.toList();
if (bigSuggestions.isEmpty) {
debugPrint('No big suggestions to sort');
return;
}
}
final startTime = DateTime.now(); final startTime = DateTime.now();
final faceMlDb = FaceMLDataDB.instance; final faceMlDb = FaceMLDataDB.instance;
// Get the cluster averages for the person's clusters and the suggestions' clusters // Get the cluster averages for the person's clusters and the suggestions' clusters
final Map<int, (Uint8List, int)> clusterToSummary = final personClusters = await faceMlDb.getPersonClusterIDs(person.remoteID);
await faceMlDb.getAllClusterSummary(); final Map<int, (Uint8List, int)> personClusterToSummary =
await faceMlDb.getClusterToClusterSummary(personClusters);
final clusterSummaryCallTime = DateTime.now();
// Calculate the avg embedding of the person // Calculate the avg embedding of the person
final personClusters = await faceMlDb.getPersonClusterIDs(person.remoteID); final w = (kDebugMode ? EnteWatch('sortSuggestions') : null)?..start();
final personEmbeddingsCount = personClusters final personEmbeddingsCount = personClusters
.map((e) => clusterToSummary[e]!.$2) .map((e) => personClusterToSummary[e]!.$2)
.reduce((a, b) => a + b); .reduce((a, b) => a + b);
final List<double> personAvg = List.filled(192, 0); Vector personAvg = Vector.filled(192, 0);
for (final personClusterID in personClusters) { for (final personClusterID in personClusters) {
final personClusterBlob = clusterToSummary[personClusterID]!.$1; final personClusterBlob = personClusterToSummary[personClusterID]!.$1;
final personClusterAvg = EVector.fromBuffer(personClusterBlob).values; final personClusterAvg = Vector.fromList(
EVector.fromBuffer(personClusterBlob).values,
dtype: DType.float32,
);
final clusterWeight = final clusterWeight =
clusterToSummary[personClusterID]!.$2 / personEmbeddingsCount; personClusterToSummary[personClusterID]!.$2 / personEmbeddingsCount;
for (int i = 0; i < personClusterAvg.length; i++) { personAvg += personClusterAvg * clusterWeight;
personAvg[i] += personClusterAvg[i] *
clusterWeight; // Weighted sum of the cluster averages
}
} }
w?.log('calculated person avg');
// Sort the suggestions based on the distance to the person // Sort the suggestions based on the distance to the person
for (final suggestion in suggestions) { for (final suggestion in suggestions) {
if (onlySortBigSuggestions) {
if (suggestion.filesInCluster.length <= 8) {
continue;
}
}
final clusterID = suggestion.clusterIDToMerge; final clusterID = suggestion.clusterIDToMerge;
final faceIdToEmbeddingMap = await faceMlDb.getFaceEmbeddingMapForFile( final faceIDs = suggestion.faceIDsInCluster;
suggestion.filesInCluster.map((e) => e.uploadedFileID!).toList(), final faceIdToEmbeddingMap = await faceMlDb.getFaceEmbeddingMapForFaces(
faceIDs,
);
final faceIdToVectorMap = faceIdToEmbeddingMap.map(
(key, value) => MapEntry(
key,
Vector.fromList(
EVector.fromBuffer(value).values,
dtype: DType.float32,
),
),
);
w?.log(
'got ${faceIdToEmbeddingMap.values.length} embeddings for ${suggestion.filesInCluster.length} files for cluster $clusterID',
); );
final fileIdToDistanceMap = {}; final fileIdToDistanceMap = {};
for (final entry in faceIdToEmbeddingMap.entries) { for (final entry in faceIdToVectorMap.entries) {
fileIdToDistanceMap[getFileIdFromFaceId(entry.key)] = fileIdToDistanceMap[getFileIdFromFaceId(entry.key)] =
cosineDistForNormVectors( cosineDistanceSIMD(personAvg, entry.value);
personAvg,
EVector.fromBuffer(entry.value).values,
);
} }
w?.log('calculated distances for cluster $clusterID');
suggestion.filesInCluster.sort((b, a) { suggestion.filesInCluster.sort((b, a) {
//todo: review with @laurens, added this to avoid null safety issue //todo: review with @laurens, added this to avoid null safety issue
final double distanceA = fileIdToDistanceMap[a.uploadedFileID!] ?? -1; final double distanceA = fileIdToDistanceMap[a.uploadedFileID!] ?? -1;
final double distanceB = fileIdToDistanceMap[b.uploadedFileID!] ?? -1; final double distanceB = fileIdToDistanceMap[b.uploadedFileID!] ?? -1;
return distanceA.compareTo(distanceB); return distanceA.compareTo(distanceB);
}); });
w?.log('sorted files for cluster $clusterID');
debugPrint( debugPrint(
"[${_logger.name}] Sorted suggestions for cluster $clusterID based on distance to person: ${suggestion.filesInCluster.map((e) => fileIdToDistanceMap[e.uploadedFileID]).toList()}", "[${_logger.name}] Sorted suggestions for cluster $clusterID based on distance to person: ${suggestion.filesInCluster.map((e) => fileIdToDistanceMap[e.uploadedFileID]).toList()}",
@ -899,7 +970,7 @@ class ClusterFeedbackService {
final endTime = DateTime.now(); final endTime = DateTime.now();
_logger.info( _logger.info(
"Sorting suggestions based on distance to person took ${endTime.difference(startTime).inMilliseconds} ms for ${suggestions.length} suggestions", "Sorting suggestions based on distance to person took ${endTime.difference(startTime).inMilliseconds} ms for ${suggestions.length} suggestions, of which ${clusterSummaryCallTime.difference(startTime).inMilliseconds} ms was spent on the cluster summary call",
); );
} }
} }

View File

@ -1,3 +1,4 @@
import "dart:async" show unawaited;
import "dart:convert"; import "dart:convert";
import "package:flutter/foundation.dart"; import "package:flutter/foundation.dart";
@ -102,10 +103,12 @@ class PersonService {
faces: faceIds.toSet(), faces: faceIds.toSet(),
); );
personData.assigned!.add(clusterInfo); personData.assigned!.add(clusterInfo);
await entityService.addOrUpdate( unawaited(
EntityType.person, entityService.addOrUpdate(
json.encode(personData.toJson()), EntityType.person,
id: personID, json.encode(personData.toJson()),
id: personID,
),
); );
await faceMLDataDB.assignClusterToPerson( await faceMLDataDB.assignClusterToPerson(
personID: personID, personID: personID,
@ -190,7 +193,7 @@ class PersonService {
} }
logger.info("Storing feedback for ${faceIdToClusterID.length} faces"); logger.info("Storing feedback for ${faceIdToClusterID.length} faces");
await faceMLDataDB.updateClusterIdToFaceId(faceIdToClusterID); await faceMLDataDB.updateFaceIdToClusterId(faceIdToClusterID);
await faceMLDataDB.bulkAssignClusterToPersonID(clusterToPersonID); await faceMLDataDB.bulkAssignClusterToPersonID(clusterToPersonID);
} }

View File

@ -264,13 +264,56 @@ class _FaceWidgetState extends State<FaceWidget> {
}, },
child: Column( child: Column(
children: [ children: [
SizedBox( Stack(
width: 60, children: [
height: 60, Container(
child: CroppedFaceImgImageView( height: 60,
enteFile: widget.file, width: 60,
face: widget.face, decoration: ShapeDecoration(
), shape: RoundedRectangleBorder(
borderRadius: const BorderRadius.all(
Radius.elliptical(16, 12),
),
side: widget.highlight
? BorderSide(
color: getEnteColorScheme(context).primary700,
width: 1.0,
)
: BorderSide.none,
),
),
child: ClipRRect(
borderRadius:
const BorderRadius.all(Radius.elliptical(16, 12)),
child: SizedBox(
width: 60,
height: 60,
child: CroppedFaceImgImageView(
enteFile: widget.file,
face: widget.face,
),
),
),
),
// TODO: the edges of the green line are still not properly rounded around ClipRRect
if (widget.editMode)
Positioned(
right: 0,
top: 0,
child: GestureDetector(
onTap: _cornerIconPressed,
child: isJustRemoved
? const Icon(
CupertinoIcons.add_circled_solid,
color: Colors.green,
)
: const Icon(
Icons.cancel,
color: Colors.red,
),
),
),
],
), ),
const SizedBox(height: 8), const SizedBox(height: 8),
if (widget.person != null) if (widget.person != null)

View File

@ -71,9 +71,9 @@ class _FacesItemWidgetState extends State<FacesItemWidget> {
]; ];
} }
// Remove faces with low scores and blurry faces // Remove faces with low scores
if (!kDebugMode) { if (!kDebugMode) {
faces.removeWhere((face) => (face.isBlurry || face.score < 0.75)); faces.removeWhere((face) => (face.score < 0.75));
} }
if (faces.isEmpty) { if (faces.isEmpty) {
@ -85,9 +85,6 @@ class _FacesItemWidgetState extends State<FacesItemWidget> {
]; ];
} }
// Sort the faces by score in descending order, so that the highest scoring face is first.
faces.sort((Face a, Face b) => b.score.compareTo(a.score));
// TODO: add deduplication of faces of same person // TODO: add deduplication of faces of same person
final faceIdsToClusterIds = await FaceMLDataDB.instance final faceIdsToClusterIds = await FaceMLDataDB.instance
.getFaceIdsToClusterIds(faces.map((face) => face.faceID)); .getFaceIdsToClusterIds(faces.map((face) => face.faceID));
@ -96,6 +93,29 @@ class _FacesItemWidgetState extends State<FacesItemWidget> {
final clusterIDToPerson = final clusterIDToPerson =
await FaceMLDataDB.instance.getClusterIDToPersonID(); await FaceMLDataDB.instance.getClusterIDToPersonID();
// Sort faces by name and score
final faceIdToPersonID = <String, String>{};
for (final face in faces) {
final clusterID = faceIdsToClusterIds[face.faceID];
if (clusterID != null) {
final personID = clusterIDToPerson[clusterID];
if (personID != null) {
faceIdToPersonID[face.faceID] = personID;
}
}
}
faces.sort((Face a, Face b) {
final aPersonID = faceIdToPersonID[a.faceID];
final bPersonID = faceIdToPersonID[b.faceID];
if (aPersonID != null && bPersonID == null) {
return -1;
} else if (aPersonID == null && bPersonID != null) {
return 1;
} else {
return b.score.compareTo(a.score);
}
});
final lastViewedClusterID = ClusterFeedbackService.lastViewedClusterID; final lastViewedClusterID = ClusterFeedbackService.lastViewedClusterID;
final faceWidgets = <FaceWidget>[]; final faceWidgets = <FaceWidget>[];

View File

@ -207,14 +207,14 @@ class _AppBarWidgetState extends State<ClusterAppBar> {
if (embedding.key == otherEmbedding.key) { if (embedding.key == otherEmbedding.key) {
continue; continue;
} }
final distance64 = 1.0 - final distance64 = cosineDistanceSIMD(
Vector.fromList(embedding.value, dtype: DType.float64).dot( Vector.fromList(embedding.value, dtype: DType.float64),
Vector.fromList(otherEmbedding.value, dtype: DType.float64), Vector.fromList(otherEmbedding.value, dtype: DType.float64),
); );
final distance32 = 1.0 - final distance32 = cosineDistanceSIMD(
Vector.fromList(embedding.value, dtype: DType.float32).dot( Vector.fromList(embedding.value, dtype: DType.float32),
Vector.fromList(otherEmbedding.value, dtype: DType.float32), Vector.fromList(otherEmbedding.value, dtype: DType.float32),
); );
final distance = cosineDistForNormVectors( final distance = cosineDistForNormVectors(
embedding.value, embedding.value,
otherEmbedding.value, otherEmbedding.value,

View File

@ -1,3 +1,4 @@
import "dart:async" show StreamSubscription, unawaited;
import "dart:math"; import "dart:math";
import "package:flutter/foundation.dart" show kDebugMode; import "package:flutter/foundation.dart" show kDebugMode;
@ -29,16 +30,25 @@ class PersonReviewClusterSuggestion extends StatefulWidget {
class _PersonClustersState extends State<PersonReviewClusterSuggestion> { class _PersonClustersState extends State<PersonReviewClusterSuggestion> {
int currentSuggestionIndex = 0; int currentSuggestionIndex = 0;
bool fetch = true;
Key futureBuilderKey = UniqueKey(); Key futureBuilderKey = UniqueKey();
// Declare a variable for the future // Declare a variable for the future
late Future<List<ClusterSuggestion>> futureClusterSuggestions; late Future<List<ClusterSuggestion>> futureClusterSuggestions;
late StreamSubscription<PeopleChangedEvent> _peopleChangedEvent;
@override @override
void initState() { void initState() {
super.initState(); super.initState();
// Initialize the future in initState // Initialize the future in initState
_fetchClusterSuggestions(); if (fetch) _fetchClusterSuggestions();
fetch = true;
}
@override
void dispose() {
_peopleChangedEvent.cancel();
super.dispose();
} }
@override @override
@ -61,12 +71,27 @@ class _PersonClustersState extends State<PersonReviewClusterSuggestion> {
), ),
); );
} }
final numberOfDifferentSuggestions = snapshot.data!.length;
final currentSuggestion = snapshot.data![currentSuggestionIndex]; final allSuggestions = snapshot.data!;
final numberOfDifferentSuggestions = allSuggestions.length;
final currentSuggestion = allSuggestions[currentSuggestionIndex];
final int clusterID = currentSuggestion.clusterIDToMerge; final int clusterID = currentSuggestion.clusterIDToMerge;
final double distance = currentSuggestion.distancePersonToCluster; final double distance = currentSuggestion.distancePersonToCluster;
final bool usingMean = currentSuggestion.usedOnlyMeanForSuggestion; final bool usingMean = currentSuggestion.usedOnlyMeanForSuggestion;
final List<EnteFile> files = currentSuggestion.filesInCluster; final List<EnteFile> files = currentSuggestion.filesInCluster;
_peopleChangedEvent =
Bus.instance.on<PeopleChangedEvent>().listen((event) {
if (event.type == PeopleEventType.removedFilesFromCluster &&
(event.source == clusterID.toString())) {
for (var updatedFile in event.relevantFiles!) {
files.remove(updatedFile);
}
fetch = false;
setState(() {});
}
});
return InkWell( return InkWell(
onTap: () { onTap: () {
Navigator.of(context).push( Navigator.of(context).push(
@ -90,6 +115,7 @@ class _PersonClustersState extends State<PersonReviewClusterSuggestion> {
usingMean, usingMean,
files, files,
numberOfDifferentSuggestions, numberOfDifferentSuggestions,
allSuggestions,
), ),
), ),
); );
@ -116,20 +142,25 @@ class _PersonClustersState extends State<PersonReviewClusterSuggestion> {
clusterID: clusterID, clusterID: clusterID,
); );
Bus.instance.fire(PeopleChangedEvent()); Bus.instance.fire(PeopleChangedEvent());
// Increment the suggestion index
if (mounted) {
setState(() => currentSuggestionIndex++);
}
// Check if we need to fetch new data
if (currentSuggestionIndex >= (numberOfSuggestions)) {
setState(() {
currentSuggestionIndex = 0;
futureBuilderKey = UniqueKey(); // Reset to trigger FutureBuilder
_fetchClusterSuggestions();
});
}
} else { } else {
await FaceMLDataDB.instance.captureNotPersonFeedback( await FaceMLDataDB.instance.captureNotPersonFeedback(
personID: widget.person.remoteID, personID: widget.person.remoteID,
clusterID: clusterID, clusterID: clusterID,
); );
} // Recalculate the suggestions when a suggestion is rejected
// Increment the suggestion index
if (mounted) {
setState(() => currentSuggestionIndex++);
}
// Check if we need to fetch new data
if (currentSuggestionIndex >= (numberOfSuggestions)) {
setState(() { setState(() {
currentSuggestionIndex = 0; currentSuggestionIndex = 0;
futureBuilderKey = UniqueKey(); // Reset to trigger FutureBuilder futureBuilderKey = UniqueKey(); // Reset to trigger FutureBuilder
@ -150,9 +181,10 @@ class _PersonClustersState extends State<PersonReviewClusterSuggestion> {
bool usingMean, bool usingMean,
List<EnteFile> files, List<EnteFile> files,
int numberOfSuggestions, int numberOfSuggestions,
List<ClusterSuggestion> allSuggestions,
) { ) {
return Column( final widgetToReturn = Column(
key: ValueKey("cluster_id-$clusterID"), key: ValueKey("cluster_id-$clusterID-files-${files.length}"),
children: <Widget>[ children: <Widget>[
if (kDebugMode) if (kDebugMode)
Text( Text(
@ -228,6 +260,28 @@ class _PersonClustersState extends State<PersonReviewClusterSuggestion> {
), ),
], ],
); );
// Precompute face thumbnails for next suggestions, in case there are
const precompute = 6;
const maxComputations = 10;
int compCount = 0;
if (allSuggestions.length > currentSuggestionIndex + 1) {
for (final suggestion in allSuggestions.sublist(
currentSuggestionIndex + 1,
min(allSuggestions.length, currentSuggestionIndex + precompute),
)) {
final files = suggestion.filesInCluster;
final clusterID = suggestion.clusterIDToMerge;
for (final file in files.sublist(0, min(files.length, 8))) {
unawaited(PersonFaceWidget.precomputeFaceCrops(file, clusterID));
compCount++;
if (compCount >= maxComputations) {
break;
}
}
}
}
return widgetToReturn;
} }
List<Widget> _buildThumbnailWidgets( List<Widget> _buildThumbnailWidgets(

View File

@ -33,9 +33,64 @@ class PersonFaceWidget extends StatelessWidget {
), ),
super(key: key); super(key: key);
static Future<void> precomputeFaceCrops(file, clusterID) async {
try {
final Face? face = await FaceMLDataDB.instance.getCoverFaceForPerson(
recentFileID: file.uploadedFileID!,
clusterID: clusterID,
);
if (face == null) {
debugPrint(
"No cover face for cluster $clusterID and recentFile ${file.uploadedFileID}",
);
return;
}
final Uint8List? cachedFace = faceCropCache.get(face.faceID);
if (cachedFace != null) {
return;
}
final faceCropCacheFile = cachedFaceCropPath(face.faceID);
if ((await faceCropCacheFile.exists())) {
final data = await faceCropCacheFile.readAsBytes();
faceCropCache.put(face.faceID, data);
return;
}
EnteFile? fileForFaceCrop = file;
if (face.fileID != file.uploadedFileID!) {
fileForFaceCrop =
await FilesDB.instance.getAnyUploadedFile(face.fileID);
}
if (fileForFaceCrop == null) {
return;
}
final result = await pool.withResource(
() async => await getFaceCrops(
fileForFaceCrop!,
{
face.faceID: face.detection.box,
},
),
);
final Uint8List? computedCrop = result?[face.faceID];
if (computedCrop != null) {
faceCropCache.put(face.faceID, computedCrop);
faceCropCacheFile.writeAsBytes(computedCrop).ignore();
}
return;
} catch (e, s) {
log(
"Error getting cover face for cluster $clusterID",
error: e,
stackTrace: s,
);
return;
}
}
@override @override
Widget build(BuildContext context) { Widget build(BuildContext context) {
if (useGeneratedFaceCrops) { if (!useGeneratedFaceCrops) {
return FutureBuilder<Uint8List?>( return FutureBuilder<Uint8List?>(
future: getFaceCrop(), future: getFaceCrop(),
builder: (context, snapshot) { builder: (context, snapshot) {

View File

@ -11,7 +11,7 @@ import "package:photos/utils/thumbnail_util.dart";
import "package:pool/pool.dart"; import "package:pool/pool.dart";
final LRUMap<String, Uint8List?> faceCropCache = LRUMap(1000); final LRUMap<String, Uint8List?> faceCropCache = LRUMap(1000);
final pool = Pool(5, timeout: const Duration(seconds: 15)); final pool = Pool(10, timeout: const Duration(seconds: 15));
Future<Map<String, Uint8List>?> getFaceCrops( Future<Map<String, Uint8List>?> getFaceCrops(
EnteFile file, EnteFile file,
Map<String, FaceBox> faceBoxeMap, Map<String, FaceBox> faceBoxeMap,