mirror of
https://github.com/ente-io/ente.git
synced 2025-08-13 01:27:17 +00:00
Merge remote-tracking branch 'origin/mobile_face' into mobile_face
This commit is contained in:
@@ -181,6 +181,32 @@ class FaceMLDataDB {
|
|||||||
return maps.map((e) => e[faceEmbeddingBlob] as Uint8List);
|
return maps.map((e) => e[faceEmbeddingBlob] as Uint8List);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Loads the stored face-embedding blobs for the given clusters in a single
/// query and groups them by cluster ID.
///
/// [clusterIDs] are interpolated directly into the `IN (...)` list of the SQL
/// text. An empty iterable produces `IN ()`.
/// NOTE(review): whether the database engine accepts an empty `IN ()` list is
/// not visible here - confirm, or have callers avoid passing an empty iterable.
///
/// [limit], when non-null, is appended as one `LIMIT` clause on the whole
/// query, so it caps the total number of rows across all requested clusters,
/// not the number of embeddings per cluster.
///
/// Returns a map from cluster ID to the embedding blobs read for it. A
/// requested cluster for which the query returns no rows has no entry in the
/// map.
Future<Map<int, Iterable<Uint8List>>> getFaceEmbeddingsForClusters(
|
||||||
|
Iterable<int> clusterIDs, {
|
||||||
|
int? limit,
|
||||||
|
}) async {
|
||||||
|
final db = await instance.database;
|
||||||
|
final Map<int, List<Uint8List>> result = {};
|
||||||
|
|
||||||
|
// Join each cluster membership row to its face row to fetch the embedding.
final selectQuery = '''
|
||||||
|
SELECT fc.$fcClusterID, fe.$faceEmbeddingBlob
|
||||||
|
FROM $faceClustersTable fc
|
||||||
|
INNER JOIN $facesTable fe ON fc.$fcFaceId = fe.$faceIDColumn
|
||||||
|
WHERE fc.$fcClusterID IN (${clusterIDs.join(',')})
|
||||||
|
${limit != null ? 'LIMIT $limit' : ''}
|
||||||
|
''';
|
||||||
|
|
||||||
|
final List<Map<String, dynamic>> maps = await db.rawQuery(selectQuery);
|
||||||
|
|
||||||
|
// Group the flat result rows into one embedding list per cluster ID.
for (final map in maps) {
|
||||||
|
final clusterID = map[fcClusterID] as int;
|
||||||
|
final faceEmbedding = map[faceEmbeddingBlob] as Uint8List;
|
||||||
|
result.putIfAbsent(clusterID, () => <Uint8List>[]).add(faceEmbedding);
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
Future<Face?> getCoverFaceForPerson({
|
Future<Face?> getCoverFaceForPerson({
|
||||||
required int recentFileID,
|
required int recentFileID,
|
||||||
String? personID,
|
String? personID,
|
||||||
@@ -668,9 +694,11 @@ class FaceMLDataDB {
|
|||||||
await db.execute(deletePersonTable);
|
await db.execute(deletePersonTable);
|
||||||
await db.execute(dropClusterPersonTable);
|
await db.execute(dropClusterPersonTable);
|
||||||
await db.execute(dropNotPersonFeedbackTable);
|
await db.execute(dropNotPersonFeedbackTable);
|
||||||
|
await db.execute(dropClusterSummaryTable);
|
||||||
await db.execute(createPersonTable);
|
await db.execute(createPersonTable);
|
||||||
await db.execute(createClusterPersonTable);
|
await db.execute(createClusterPersonTable);
|
||||||
await db.execute(createNotPersonFeedbackTable);
|
await db.execute(createNotPersonFeedbackTable);
|
||||||
|
await db.execute(createClusterSummaryTable);
|
||||||
}
|
}
|
||||||
|
|
||||||
Future<void> removeFilesFromPerson(List<EnteFile> files, Person p) async {
|
Future<void> removeFilesFromPerson(List<EnteFile> files, Person p) async {
|
||||||
|
@@ -367,11 +367,13 @@ class ClusterFeedbackService {
|
|||||||
|
|
||||||
Future<Map<int, List<double>>> _getUpdateClusterAvg(
|
Future<Map<int, List<double>>> _getUpdateClusterAvg(
|
||||||
Map<int, int> allClusterIdsToCountMap,
|
Map<int, int> allClusterIdsToCountMap,
|
||||||
Set<int> ignoredClusters,
|
Set<int> ignoredClusters, {
|
||||||
) async {
|
int minClusterSize = 1,
|
||||||
|
int maxClusterInCurrentRun = 500,
|
||||||
|
}) async {
|
||||||
final faceMlDb = FaceMLDataDB.instance;
|
final faceMlDb = FaceMLDataDB.instance;
|
||||||
_logger.info(
|
_logger.info(
|
||||||
'start getUpdateClusterAvg for ${allClusterIdsToCountMap.length} clusters',
|
'start getUpdateClusterAvg for ${allClusterIdsToCountMap.length} clusters, minClusterSize $minClusterSize, maxClusterInCurrentRun $maxClusterInCurrentRun',
|
||||||
);
|
);
|
||||||
|
|
||||||
final Map<int, (Uint8List, int)> clusterToSummary =
|
final Map<int, (Uint8List, int)> clusterToSummary =
|
||||||
@@ -380,22 +382,72 @@ class ClusterFeedbackService {
|
|||||||
|
|
||||||
final Map<int, List<double>> clusterAvg = {};
|
final Map<int, List<double>> clusterAvg = {};
|
||||||
|
|
||||||
final allClusterIds = allClusterIdsToCountMap.keys;
|
final allClusterIds = allClusterIdsToCountMap.keys.toSet();
|
||||||
for (final clusterID in allClusterIds) {
|
int ignoredClustersCnt = 0, alreadyUpdatedClustersCnt = 0;
|
||||||
if (ignoredClusters.contains(clusterID)) {
|
int smallerClustersCnt = 0;
|
||||||
continue;
|
for (final id in allClusterIdsToCountMap.keys) {
|
||||||
|
if (ignoredClusters.contains(id)) {
|
||||||
|
allClusterIds.remove(id);
|
||||||
|
ignoredClustersCnt++;
|
||||||
|
}
|
||||||
|
if (clusterToSummary[id]?.$2 == allClusterIdsToCountMap[id]) {
|
||||||
|
allClusterIds.remove(id);
|
||||||
|
clusterAvg[id] = EVector.fromBuffer(clusterToSummary[id]!.$1).values;
|
||||||
|
alreadyUpdatedClustersCnt++;
|
||||||
|
}
|
||||||
|
if (allClusterIdsToCountMap[id]! < minClusterSize) {
|
||||||
|
allClusterIds.remove(id);
|
||||||
|
smallerClustersCnt++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_logger.info(
|
||||||
|
'Ignored $ignoredClustersCnt clusters, already updated $alreadyUpdatedClustersCnt clusters, $smallerClustersCnt clusters are smaller than $minClusterSize',
|
||||||
|
);
|
||||||
|
// get clusterIDs sorted by count in descending order
|
||||||
|
final sortedClusterIDs = allClusterIds.toList();
|
||||||
|
sortedClusterIDs.sort(
|
||||||
|
(a, b) =>
|
||||||
|
allClusterIdsToCountMap[b]!.compareTo(allClusterIdsToCountMap[a]!),
|
||||||
|
);
|
||||||
|
int indexedInCurrentRun = 0;
|
||||||
|
final EnteWatch? w = kDebugMode ? EnteWatch("computeAvg") : null;
|
||||||
|
w?.start();
|
||||||
|
|
||||||
|
w?.log(
|
||||||
|
'reading embeddings for $maxClusterInCurrentRun or ${sortedClusterIDs.length} clusters',
|
||||||
|
);
|
||||||
|
final int maxEmbeddingToRead = 10000;
|
||||||
|
int currentPendingRead = 0;
|
||||||
|
List<int> clusterIdsToRead = [];
|
||||||
|
for (final clusterID in sortedClusterIDs) {
|
||||||
|
if (maxClusterInCurrentRun-- <= 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (currentPendingRead == 0) {
|
||||||
|
currentPendingRead = allClusterIdsToCountMap[clusterID] ?? 0;
|
||||||
|
clusterIdsToRead.add(clusterID);
|
||||||
|
} else {
|
||||||
|
if ((currentPendingRead + allClusterIdsToCountMap[clusterID]!) <
|
||||||
|
maxEmbeddingToRead) {
|
||||||
|
clusterIdsToRead.add(clusterID);
|
||||||
|
currentPendingRead += allClusterIdsToCountMap[clusterID]!;
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (allClusterIdsToCountMap[clusterID]! < 2) {
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
final Map<int, Iterable<Uint8List>> clusterEmbeddings = await FaceMLDataDB
|
||||||
|
.instance
|
||||||
|
.getFaceEmbeddingsForClusters(clusterIdsToRead);
|
||||||
|
|
||||||
|
w?.logAndReset(
|
||||||
|
'read $currentPendingRead embeddings for ${clusterEmbeddings.length} clusters',
|
||||||
|
);
|
||||||
|
|
||||||
|
for (final clusterID in clusterEmbeddings.keys) {
|
||||||
late List<double> avg;
|
late List<double> avg;
|
||||||
if (clusterToSummary[clusterID]?.$2 ==
|
final Iterable<Uint8List> embedings = clusterEmbeddings[clusterID]!;
|
||||||
allClusterIdsToCountMap[clusterID]) {
|
|
||||||
avg = EVector.fromBuffer(clusterToSummary[clusterID]!.$1).values;
|
|
||||||
} else {
|
|
||||||
final Iterable<Uint8List> embedings =
|
|
||||||
await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID);
|
|
||||||
final List<double> sum = List.filled(192, 0);
|
final List<double> sum = List.filled(192, 0);
|
||||||
for (final embedding in embedings) {
|
for (final embedding in embedings) {
|
||||||
final data = EVector.fromBuffer(embedding).values;
|
final data = EVector.fromBuffer(embedding).values;
|
||||||
@@ -407,14 +459,14 @@ class ClusterFeedbackService {
|
|||||||
final avgEmbeedingBuffer = EVector(values: avg).writeToBuffer();
|
final avgEmbeedingBuffer = EVector(values: avg).writeToBuffer();
|
||||||
updatesForClusterSummary[clusterID] =
|
updatesForClusterSummary[clusterID] =
|
||||||
(avgEmbeedingBuffer, embedings.length);
|
(avgEmbeedingBuffer, embedings.length);
|
||||||
}
|
|
||||||
// store the intermediate updates
|
// store the intermediate updates
|
||||||
|
indexedInCurrentRun++;
|
||||||
if (updatesForClusterSummary.length > 100) {
|
if (updatesForClusterSummary.length > 100) {
|
||||||
await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);
|
await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);
|
||||||
updatesForClusterSummary.clear();
|
updatesForClusterSummary.clear();
|
||||||
if (kDebugMode) {
|
if (kDebugMode) {
|
||||||
_logger.info(
|
_logger.info(
|
||||||
'start getUpdateClusterAvg for ${allClusterIdsToCountMap.length} clusters',
|
'getUpdateClusterAvg $indexedInCurrentRun clusters in current one',
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -423,6 +475,7 @@ class ClusterFeedbackService {
|
|||||||
if (updatesForClusterSummary.isNotEmpty) {
|
if (updatesForClusterSummary.isNotEmpty) {
|
||||||
await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);
|
await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);
|
||||||
}
|
}
|
||||||
|
w?.logAndReset('done computing avg ');
|
||||||
_logger.info('end getUpdateClusterAvg for ${clusterAvg.length} clusters');
|
_logger.info('end getUpdateClusterAvg for ${clusterAvg.length} clusters');
|
||||||
|
|
||||||
return clusterAvg;
|
return clusterAvg;
|
||||||
@@ -549,8 +602,9 @@ class ClusterFeedbackService {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
suggestion.$4.sort((b, a) {
|
suggestion.$4.sort((b, a) {
|
||||||
final double distanceA = fileIdToDistanceMap[a.uploadedFileID!];
|
//todo: review with @laurens, added this to avoid null safety issue
|
||||||
final double distanceB = fileIdToDistanceMap[b.uploadedFileID!];
|
final double distanceA = fileIdToDistanceMap[a.uploadedFileID!] ?? -1;
|
||||||
|
final double distanceB = fileIdToDistanceMap[b.uploadedFileID!] ?? -1;
|
||||||
return distanceA.compareTo(distanceB);
|
return distanceA.compareTo(distanceB);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user