Merge remote-tracking branch 'origin/mobile_face' into mobile_face

This commit is contained in:
laurenspriem 2024-04-02 17:30:50 +05:30
commit 8fefc22180
2 changed files with 112 additions and 30 deletions

View File

@ -181,6 +181,32 @@ class FaceMLDataDB {
return maps.map((e) => e[faceEmbeddingBlob] as Uint8List);
}
/// Returns the face embedding blobs for the given [clusterIDs], grouped by
/// cluster ID.
///
/// Joins the face-clusters table with the faces table on the face ID and
/// collects every returned embedding blob under the cluster ID of its row.
/// Clusters for which the query yields no rows are absent from the result.
///
/// When [limit] is provided it is appended as a SQL `LIMIT` to the whole
/// query, so it caps the total number of rows read across all clusters,
/// not the number of embeddings per cluster.
///
/// Returns an empty map without touching the database when [clusterIDs]
/// is empty.
Future<Map<int, Iterable<Uint8List>>> getFaceEmbeddingsForClusters(
  Iterable<int> clusterIDs, {
  int? limit,
}) async {
  // Materialise once: the IDs are inspected for emptiness and then joined,
  // and a lazy iterable would otherwise be walked twice.
  final ids = clusterIDs.toList(growable: false);
  if (ids.isEmpty) {
    // Nothing to look up; avoid a round trip and an empty `IN ()` clause.
    return <int, Iterable<Uint8List>>{};
  }
  final db = await instance.database;
  final Map<int, List<Uint8List>> result = {};
  // Both the IDs and the limit are typed as int, so interpolating them
  // directly into the statement cannot inject arbitrary SQL.
  final limitClause = limit != null ? 'LIMIT $limit' : '';
  final selectQuery = '''
    SELECT fc.$fcClusterID, fe.$faceEmbeddingBlob
    FROM $faceClustersTable fc
    INNER JOIN $facesTable fe ON fc.$fcFaceId = fe.$faceIDColumn
    WHERE fc.$fcClusterID IN (${ids.join(',')})
    $limitClause
  ''';
  final List<Map<String, dynamic>> rows = await db.rawQuery(selectQuery);
  for (final row in rows) {
    final clusterID = row[fcClusterID] as int;
    final faceEmbedding = row[faceEmbeddingBlob] as Uint8List;
    (result[clusterID] ??= <Uint8List>[]).add(faceEmbedding);
  }
  return result;
}
Future<Face?> getCoverFaceForPerson({
required int recentFileID,
String? personID,
@ -668,9 +694,11 @@ class FaceMLDataDB {
await db.execute(deletePersonTable);
await db.execute(dropClusterPersonTable);
await db.execute(dropNotPersonFeedbackTable);
await db.execute(dropClusterSummaryTable);
await db.execute(createPersonTable);
await db.execute(createClusterPersonTable);
await db.execute(createNotPersonFeedbackTable);
await db.execute(createClusterSummaryTable);
}
Future<void> removeFilesFromPerson(List<EnteFile> files, Person p) async {

View File

@ -367,11 +367,13 @@ class ClusterFeedbackService {
Future<Map<int, List<double>>> _getUpdateClusterAvg(
Map<int, int> allClusterIdsToCountMap,
Set<int> ignoredClusters,
) async {
Set<int> ignoredClusters, {
int minClusterSize = 1,
int maxClusterInCurrentRun = 500,
}) async {
final faceMlDb = FaceMLDataDB.instance;
_logger.info(
'start getUpdateClusterAvg for ${allClusterIdsToCountMap.length} clusters',
'start getUpdateClusterAvg for ${allClusterIdsToCountMap.length} clusters, minClusterSize $minClusterSize, maxClusterInCurrentRun $maxClusterInCurrentRun',
);
final Map<int, (Uint8List, int)> clusterToSummary =
@ -380,41 +382,91 @@ class ClusterFeedbackService {
final Map<int, List<double>> clusterAvg = {};
final allClusterIds = allClusterIdsToCountMap.keys;
for (final clusterID in allClusterIds) {
if (ignoredClusters.contains(clusterID)) {
continue;
final allClusterIds = allClusterIdsToCountMap.keys.toSet();
int ignoredClustersCnt = 0, alreadyUpdatedClustersCnt = 0;
int smallerClustersCnt = 0;
for (final id in allClusterIdsToCountMap.keys) {
if (ignoredClusters.contains(id)) {
allClusterIds.remove(id);
ignoredClustersCnt++;
}
if (allClusterIdsToCountMap[clusterID]! < 2) {
continue;
if (clusterToSummary[id]?.$2 == allClusterIdsToCountMap[id]) {
allClusterIds.remove(id);
clusterAvg[id] = EVector.fromBuffer(clusterToSummary[id]!.$1).values;
alreadyUpdatedClustersCnt++;
}
if (allClusterIdsToCountMap[id]! < minClusterSize) {
allClusterIds.remove(id);
smallerClustersCnt++;
}
}
_logger.info(
'Ignored $ignoredClustersCnt clusters, already updated $alreadyUpdatedClustersCnt clusters, $smallerClustersCnt clusters are smaller than $minClusterSize',
);
// get clusterIDs sorted by count in descending order
final sortedClusterIDs = allClusterIds.toList();
sortedClusterIDs.sort(
(a, b) =>
allClusterIdsToCountMap[b]!.compareTo(allClusterIdsToCountMap[a]!),
);
int indexedInCurrentRun = 0;
final EnteWatch? w = kDebugMode ? EnteWatch("computeAvg") : null;
w?.start();
late List<double> avg;
if (clusterToSummary[clusterID]?.$2 ==
allClusterIdsToCountMap[clusterID]) {
avg = EVector.fromBuffer(clusterToSummary[clusterID]!.$1).values;
} else {
final Iterable<Uint8List> embedings =
await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID);
final List<double> sum = List.filled(192, 0);
for (final embedding in embedings) {
final data = EVector.fromBuffer(embedding).values;
for (int i = 0; i < sum.length; i++) {
sum[i] += data[i];
}
}
avg = sum.map((e) => e / embedings.length).toList();
final avgEmbeedingBuffer = EVector(values: avg).writeToBuffer();
updatesForClusterSummary[clusterID] =
(avgEmbeedingBuffer, embedings.length);
w?.log(
'reading embeddings for $maxClusterInCurrentRun or ${sortedClusterIDs.length} clusters',
);
final int maxEmbeddingToRead = 10000;
int currentPendingRead = 0;
List<int> clusterIdsToRead = [];
for (final clusterID in sortedClusterIDs) {
if (maxClusterInCurrentRun-- <= 0) {
break;
}
if (currentPendingRead == 0) {
currentPendingRead = allClusterIdsToCountMap[clusterID] ?? 0;
clusterIdsToRead.add(clusterID);
} else {
if ((currentPendingRead + allClusterIdsToCountMap[clusterID]!) <
maxEmbeddingToRead) {
clusterIdsToRead.add(clusterID);
currentPendingRead += allClusterIdsToCountMap[clusterID]!;
} else {
break;
}
}
}
final Map<int, Iterable<Uint8List>> clusterEmbeddings = await FaceMLDataDB
.instance
.getFaceEmbeddingsForClusters(clusterIdsToRead);
w?.logAndReset(
'read $currentPendingRead embeddings for ${clusterEmbeddings.length} clusters',
);
for (final clusterID in clusterEmbeddings.keys) {
late List<double> avg;
final Iterable<Uint8List> embedings = clusterEmbeddings[clusterID]!;
final List<double> sum = List.filled(192, 0);
for (final embedding in embedings) {
final data = EVector.fromBuffer(embedding).values;
for (int i = 0; i < sum.length; i++) {
sum[i] += data[i];
}
}
avg = sum.map((e) => e / embedings.length).toList();
final avgEmbeedingBuffer = EVector(values: avg).writeToBuffer();
updatesForClusterSummary[clusterID] =
(avgEmbeedingBuffer, embedings.length);
// store the intermediate updates
indexedInCurrentRun++;
if (updatesForClusterSummary.length > 100) {
await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);
updatesForClusterSummary.clear();
if (kDebugMode) {
_logger.info(
'start getUpdateClusterAvg for ${allClusterIdsToCountMap.length} clusters',
'getUpdateClusterAvg $indexedInCurrentRun clusters in current one',
);
}
}
@ -423,6 +475,7 @@ class ClusterFeedbackService {
if (updatesForClusterSummary.isNotEmpty) {
await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);
}
w?.logAndReset('done computing avg ');
_logger.info('end getUpdateClusterAvg for ${clusterAvg.length} clusters');
return clusterAvg;
@ -549,8 +602,9 @@ class ClusterFeedbackService {
);
}
suggestion.$4.sort((b, a) {
final double distanceA = fileIdToDistanceMap[a.uploadedFileID!];
final double distanceB = fileIdToDistanceMap[b.uploadedFileID!];
// TODO: review with @laurens. The `?? -1` fallback below was added only to satisfy null safety when a file ID is missing from fileIdToDistanceMap.
final double distanceA = fileIdToDistanceMap[a.uploadedFileID!] ?? -1;
final double distanceB = fileIdToDistanceMap[b.uploadedFileID!] ?? -1;
return distanceA.compareTo(distanceB);
});