This commit is contained in:
Neeraj Gupta 2025-01-29 15:33:46 +05:30
parent 29f4bbb0de
commit 09b88e5bab
4 changed files with 56 additions and 57 deletions

View File

@ -39,7 +39,7 @@ class ClusterFeedbackService<T> {
final Logger _logger = Logger("ClusterFeedbackService");
final _computer = Computer.shared();
ClusterFeedbackService._privateConstructor();
late final faceMLDB = MLDataDB.instance;
late final mlDataDB = MLDataDB.instance;
static final ClusterFeedbackService instance =
ClusterFeedbackService._privateConstructor();
@ -79,11 +79,11 @@ class ClusterFeedbackService<T> {
// Get the files for the suggestions
final suggestionClusterIDs = foundSuggestions.map((e) => e.$1).toSet();
final Map<int, Set<String>> fileIdToClusterID =
await faceMLDB.getFileIdToClusterIDSetForCluster(
await mlDataDB.getFileIdToClusterIDSetForCluster(
suggestionClusterIDs,
);
final clusterIdToFaceIDs =
await faceMLDB.getClusterToFaceIDs(suggestionClusterIDs);
await mlDataDB.getClusterToFaceIDs(suggestionClusterIDs);
final Map<String, List<EnteFile>> clusterIDToFiles = {};
final allFiles = await SearchService.instance.getAllFilesForSearch();
for (final f in allFiles) {
@ -142,14 +142,14 @@ class ClusterFeedbackService<T> {
try {
_logger.info('removeFilesFromPerson called');
// Get the relevant faces to be removed
final faceIDs = await faceMLDB
final faceIDs = await mlDataDB
.getFaceIDsForPerson(p.remoteID)
.then((iterable) => iterable.toList());
faceIDs.retainWhere((faceID) {
final fileID = getFileIdFromFaceId<int>(faceID);
return files.any((file) => file.uploadedFileID == fileID);
});
final embeddings = await faceMLDB.getFaceEmbeddingMapForFaces(faceIDs);
final embeddings = await mlDataDB.getFaceEmbeddingMapForFaces(faceIDs);
if (faceIDs.isEmpty || embeddings.isEmpty) {
_logger.severe(
@ -175,15 +175,15 @@ class ClusterFeedbackService<T> {
final newFaceIdToClusterID = clusterResult.newFaceIdToCluster;
// Update the deleted faces
await faceMLDB.forceUpdateClusterIds(newFaceIdToClusterID);
await faceMLDB.clusterSummaryUpdate(clusterResult.newClusterSummaries);
await mlDataDB.forceUpdateClusterIds(newFaceIdToClusterID);
await mlDataDB.clusterSummaryUpdate(clusterResult.newClusterSummaries);
// Make sure the deleted faces don't get suggested in the future
final notClusterIdToPersonId = <String, String>{};
for (final clusterId in newFaceIdToClusterID.values.toSet()) {
notClusterIdToPersonId[clusterId] = p.remoteID;
}
await faceMLDB.bulkCaptureNotPersonFeedback(notClusterIdToPersonId);
await mlDataDB.bulkCaptureNotPersonFeedback(notClusterIdToPersonId);
// Update remote so new sync does not undo this change
await PersonService.instance
@ -205,14 +205,14 @@ class ClusterFeedbackService<T> {
_logger.info('removeFilesFromCluster called');
try {
// Get the relevant faces to be removed
final faceIDs = await faceMLDB
final faceIDs = await mlDataDB
.getFaceIDsForCluster(clusterID)
.then((iterable) => iterable.toList());
faceIDs.retainWhere((faceID) {
final fileID = getFileIdFromFaceId<int>(faceID);
return files.any((file) => file.uploadedFileID == fileID);
});
final embeddings = await faceMLDB.getFaceEmbeddingMapForFaces(faceIDs);
final embeddings = await mlDataDB.getFaceEmbeddingMapForFaces(faceIDs);
if (faceIDs.isEmpty || embeddings.isEmpty) {
_logger.severe(
@ -238,8 +238,8 @@ class ClusterFeedbackService<T> {
final newFaceIdToClusterID = clusterResult.newFaceIdToCluster;
// Update the deleted faces
await faceMLDB.forceUpdateClusterIds(newFaceIdToClusterID);
await faceMLDB.clusterSummaryUpdate(clusterResult.newClusterSummaries);
await mlDataDB.forceUpdateClusterIds(newFaceIdToClusterID);
await mlDataDB.clusterSummaryUpdate(clusterResult.newClusterSummaries);
Bus.instance.fire(
PeopleChangedEvent(
@ -261,7 +261,7 @@ class ClusterFeedbackService<T> {
for (final faceID in faceIDs) {
faceIDToClusterID[faceID] = clusterID;
}
await faceMLDB.forceUpdateClusterIds(faceIDToClusterID);
await mlDataDB.forceUpdateClusterIds(faceIDToClusterID);
Bus.instance.fire(PeopleChangedEvent());
return;
}
@ -270,8 +270,8 @@ class ClusterFeedbackService<T> {
PersonEntity p, {
required String personClusterID,
}) async {
final faceIDs = await faceMLDB.getFaceIDsForCluster(personClusterID);
final ignoredClusters = await faceMLDB.getPersonIgnoredClusters(p.remoteID);
final faceIDs = await mlDataDB.getFaceIDsForCluster(personClusterID);
final ignoredClusters = await mlDataDB.getPersonIgnoredClusters(p.remoteID);
if (faceIDs.length < 2 * kMinimumClusterSizeSearchResult) {
final fileIDs = faceIDs.map(getFileIdFromFaceId<int>).toSet();
if (fileIDs.length < kMinimumClusterSizeSearchResult) {
@ -281,7 +281,7 @@ class ClusterFeedbackService<T> {
return false;
}
}
final allClusterIdsToCountMap = (await faceMLDB.clusterIdToFaceCount());
final allClusterIdsToCountMap = (await mlDataDB.clusterIdToFaceCount());
_logger.info(
'${kDebugMode ? p.data.name : "private"} has existing clusterID $personClusterID, checking if we can automatically merge more',
);
@ -318,7 +318,7 @@ class ClusterFeedbackService<T> {
for (final suggestion in suggestions) {
final clusterID = suggestion.$1;
await faceMLDB.assignClusterToPerson(
await mlDataDB.assignClusterToPerson(
personID: p.remoteID,
clusterID: clusterID,
);
@ -334,8 +334,7 @@ class ClusterFeedbackService<T> {
required String clusterID,
}) async {
if (person.data.rejectedFaceIDs.isNotEmpty) {
final clusterFaceIDs =
await MLDataDB.instance.getFaceIDsForCluster(clusterID);
final clusterFaceIDs = await mlDataDB.getFaceIDsForCluster(clusterID);
final rejectedLengthBefore = person.data.rejectedFaceIDs.length;
person.data.rejectedFaceIDs
.removeWhere((faceID) => clusterFaceIDs.contains(faceID));
@ -347,7 +346,7 @@ class ClusterFeedbackService<T> {
await PersonService.instance.updatePerson(person);
}
}
await faceMLDB.assignClusterToPerson(
await mlDataDB.assignClusterToPerson(
personID: person.remoteID,
clusterID: clusterID,
);
@ -362,7 +361,7 @@ class ClusterFeedbackService<T> {
}
Future<List<(String, int)>> checkForMixedClusters() async {
final allClusterToFaceCount = await faceMLDB.clusterIdToFaceCount();
final allClusterToFaceCount = await mlDataDB.clusterIdToFaceCount();
final clustersToInspect = <String>[];
for (final clusterID in allClusterToFaceCount.keys) {
if (allClusterToFaceCount[clusterID]! > 20 &&
@ -379,9 +378,9 @@ class ClusterFeedbackService<T> {
final inspectionStart = DateTime.now();
for (final clusterID in clustersToInspect) {
final int originalClusterSize = allClusterToFaceCount[clusterID]!;
final faceIDs = await faceMLDB.getFaceIDsForCluster(clusterID);
final faceIDs = await mlDataDB.getFaceIDsForCluster(clusterID);
final embeddings = await faceMLDB.getFaceEmbeddingMapForFaces(faceIDs);
final embeddings = await mlDataDB.getFaceEmbeddingMapForFaces(faceIDs);
final clusterResult =
await FaceClusteringService.instance.predictWithinClusterComputer(
@ -425,7 +424,7 @@ class ClusterFeedbackService<T> {
if (biggestRatio < 0.5 || secondBiggestRatio > 0.2) {
final faceIdsOfCluster =
await faceMLDB.getFaceIDsForCluster(clusterID);
await mlDataDB.getFaceIDsForCluster(clusterID);
final uniqueFileIDs =
faceIdsOfCluster.map(getFileIdFromFaceId<int>).toSet();
susClusters.add((clusterID, uniqueFileIDs.length));
@ -459,10 +458,10 @@ class ClusterFeedbackService<T> {
_logger.info(
'breakUpCluster called for cluster $clusterID with dbscan $useDbscan',
);
final faceIDs = await faceMLDB.getFaceIDsForCluster(clusterID);
final faceIDs = await mlDataDB.getFaceIDsForCluster(clusterID);
final originalFaceIDsSet = faceIDs.toSet();
final embeddings = await faceMLDB.getFaceEmbeddingMapForFaces(faceIDs);
final embeddings = await mlDataDB.getFaceEmbeddingMapForFaces(faceIDs);
if (embeddings.isEmpty) {
_logger.warning('No embeddings found for cluster $clusterID');
@ -520,15 +519,15 @@ class ClusterFeedbackService<T> {
}) async {
final w = (kDebugMode ? EnteWatch('getSuggestions') : null)?..start();
// Get all the cluster data
final allClusterIdsToCountMap = await faceMLDB.clusterIdToFaceCount();
final ignoredClusters = await faceMLDB.getPersonIgnoredClusters(p.remoteID);
final personClusters = await faceMLDB.getPersonClusterIDs(p.remoteID);
final personFaceIDs = await faceMLDB.getFaceIDsForPerson(p.remoteID);
final allClusterIdsToCountMap = await mlDataDB.clusterIdToFaceCount();
final ignoredClusters = await mlDataDB.getPersonIgnoredClusters(p.remoteID);
final personClusters = await mlDataDB.getPersonClusterIDs(p.remoteID);
final personFaceIDs = await mlDataDB.getFaceIDsForPerson(p.remoteID);
final personFileIDs = personFaceIDs.map(getFileIdFromFaceId<int>).toSet();
w?.log(
'${p.data.name} has ${personClusters.length} existing clusters, getting all database data done',
);
final allClusterIdToFaceIDs = await faceMLDB.getAllClusterIdToFaceIDs();
final allClusterIdToFaceIDs = await mlDataDB.getAllClusterIdToFaceIDs();
w?.log('getAllClusterIdToFaceIDs done');
// First only do a simple check on the big clusters, if the person does not have small clusters yet
@ -567,7 +566,7 @@ class ClusterFeedbackService<T> {
final overlap = personFileIDs.intersection(suggestionSet);
if (overlap.isNotEmpty &&
((overlap.length / suggestionSet.length) > 0.5)) {
await faceMLDB.captureNotPersonFeedback(
await mlDataDB.captureNotPersonFeedback(
personID: p.remoteID,
clusterID: suggestion.$1,
);
@ -614,7 +613,7 @@ class ClusterFeedbackService<T> {
final List<Uint8List> personEmbeddingsProto = [];
for (final clusterID in personClusters) {
final Iterable<Uint8List> embeddings =
await faceMLDB.getFaceEmbeddingsForCluster(clusterID);
await mlDataDB.getFaceEmbeddingsForCluster(clusterID);
personEmbeddingsProto.addAll(embeddings);
}
final List<Uint8List> sampledEmbeddingsProto =
@ -637,7 +636,7 @@ class ClusterFeedbackService<T> {
double minMedianDistance = maxMedianDistance;
for (final otherClusterId in otherClusterIdsCandidates) {
final Iterable<Uint8List> otherEmbeddingsProto =
await faceMLDB.getFaceEmbeddingsForCluster(
await mlDataDB.getFaceEmbeddingsForCluster(
otherClusterId,
);
final sampledOtherEmbeddingsProto = _randomSampleWithoutReplacement(
@ -718,7 +717,7 @@ class ClusterFeedbackService<T> {
);
final Map<String, (Uint8List, int)> clusterToSummary =
await faceMLDB.getAllClusterSummary(minClusterSize);
await mlDataDB.getAllClusterSummary(minClusterSize);
final Map<String, (Uint8List, int)> updatesForClusterSummary = {};
w?.log(
@ -789,7 +788,7 @@ class ClusterFeedbackService<T> {
}
final Map<String, Iterable<Uint8List>> clusterEmbeddings =
await faceMLDB.getFaceEmbeddingsForClusters(clusterIdsToRead);
await mlDataDB.getFaceEmbeddingsForClusters(clusterIdsToRead);
w?.logAndReset(
'read $currentPendingRead embeddings for ${clusterEmbeddings.length} clusters',
@ -811,7 +810,7 @@ class ClusterFeedbackService<T> {
// store the intermediate updates
indexedInCurrentRun++;
if (updatesForClusterSummary.length > 100) {
await faceMLDB.clusterSummaryUpdate(updatesForClusterSummary);
await mlDataDB.clusterSummaryUpdate(updatesForClusterSummary);
updatesForClusterSummary.clear();
if (kDebugMode) {
_logger.info(
@ -822,7 +821,7 @@ class ClusterFeedbackService<T> {
clusterAvg[clusterID] = avgNormalized;
}
if (updatesForClusterSummary.isNotEmpty) {
await faceMLDB.clusterSummaryUpdate(updatesForClusterSummary);
await mlDataDB.clusterSummaryUpdate(updatesForClusterSummary);
}
w?.logAndReset('done computing avg ');
_logger.info(
@ -902,9 +901,9 @@ class ClusterFeedbackService<T> {
final startTime = DateTime.now();
// Get the cluster averages for the person's clusters and the suggestions' clusters
final personClusters = await faceMLDB.getPersonClusterIDs(person.remoteID);
final personClusters = await mlDataDB.getPersonClusterIDs(person.remoteID);
final Map<String, (Uint8List, int)> personClusterToSummary =
await faceMLDB.getClusterToClusterSummary(personClusters);
await mlDataDB.getClusterToClusterSummary(personClusters);
final clusterSummaryCallTime = DateTime.now();
// remove personClusters that don't have any summary
@ -948,7 +947,7 @@ class ClusterFeedbackService<T> {
}
final clusterID = suggestion.clusterIDToMerge;
final faceIDs = suggestion.faceIDsInCluster;
final faceIdToEmbeddingMap = await faceMLDB.getFaceEmbeddingMapForFaces(
final faceIdToEmbeddingMap = await mlDataDB.getFaceEmbeddingMapForFaces(
faceIDs,
);
final faceIdToVectorMap = faceIdToEmbeddingMap.map(
@ -1005,7 +1004,7 @@ class ClusterFeedbackService<T> {
// Logging the cluster summary for the cluster
if (logClusterSummary) {
final summaryMap = await faceMLDB.getClusterToClusterSummary(
final summaryMap = await mlDataDB.getClusterToClusterSummary(
[clusterID, biggestClusterID],
);
final summary = summaryMap[clusterID];
@ -1052,7 +1051,7 @@ class ClusterFeedbackService<T> {
// Median distance
const sampleSize = 100;
final Iterable<Uint8List> biggestEmbeddings =
await faceMLDB.getFaceEmbeddingsForCluster(biggestClusterID);
await mlDataDB.getFaceEmbeddingsForCluster(biggestClusterID);
final List<Uint8List> biggestSampledEmbeddingsProto =
_randomSampleWithoutReplacement(
biggestEmbeddings,
@ -1069,7 +1068,7 @@ class ClusterFeedbackService<T> {
.toList(growable: false);
final Iterable<Uint8List> currentEmbeddings =
await faceMLDB.getFaceEmbeddingsForCluster(clusterID);
await mlDataDB.getFaceEmbeddingsForCluster(clusterID);
final List<Uint8List> currentSampledEmbeddingsProto =
_randomSampleWithoutReplacement(
currentEmbeddings,
@ -1115,7 +1114,7 @@ class ClusterFeedbackService<T> {
// Logging the blur values for the cluster
if (logBlurValues) {
final List<double> blurValues = await faceMLDB
final List<double> blurValues = await mlDataDB
.getBlurValuesForCluster(clusterID)
.then((value) => value.toList());
final blurValuesIntegers =

View File

@ -61,6 +61,7 @@ class SearchService {
final _logger = Logger((SearchService).toString());
final _collectionService = CollectionsService.instance;
static const _maximumResultsLimit = 20;
late final mlDataDB = MLDataDB.instance;
SearchService._privateConstructor();
@ -764,7 +765,7 @@ class SearchService {
) async {
_logger.info('getClusterFilesForPersonID $personID');
final Map<int, Set<String>> fileIdToClusterID =
await MLDataDB.instance.getFileIdToClusterIDSet(personID);
await mlDataDB.getFileIdToClusterIDSet(personID);
_logger.info('faceDbDone getClusterFilesForPersonID $personID');
final Map<String, List<EnteFile>> clusterIDToFiles = {};
final allFiles = await getAllFilesForSearch();
@ -792,11 +793,10 @@ class SearchService {
try {
debugPrint("getting faces");
final Map<int, Set<String>> fileIdToClusterID =
await MLDataDB.instance.getFileIdToClusterIds();
await mlDataDB.getFileIdToClusterIds();
final Map<String, PersonEntity> personIdToPerson =
await PersonService.instance.getPersonsMap();
final clusterIDToPersonID =
await MLDataDB.instance.getClusterIDToPersonID();
final clusterIDToPersonID = await mlDataDB.getClusterIDToPersonID();
final List<GenericSearchResult> facesResult = [];
final Map<String, List<EnteFile>> clusterIdToFiles = {};
@ -920,7 +920,7 @@ class SearchService {
"`getAllFace`: Cluster $clusterId should not have person id ${clusterIDToPersonID[clusterId]}, deleting the mapping",
Exception('ClusterID assigned to a person that no longer exists'),
);
await MLDataDB.instance.removeClusterToPerson(
await mlDataDB.removeClusterToPerson(
personID: personID,
clusterID: clusterId,
);

View File

@ -49,6 +49,7 @@ class _PersonClustersState extends State<PersonReviewClusterSuggestion> {
List<SuggestionUserFeedback> pastUserFeedback = [];
List<ClusterSuggestion> allSuggestions = [];
late final Logger _logger = Logger('_PersonClustersState');
late final mlDataDB = MLDataDB.instance;
// Declare a variable for the future
late Future<List<ClusterSuggestion>> futureClusterSuggestions;
@ -283,7 +284,7 @@ class _PersonClustersState extends State<PersonReviewClusterSuggestion> {
int numberOfSuggestions,
) async {
canGiveFeedback = false;
await MLDataDB.instance.captureNotPersonFeedback(
await mlDataDB.captureNotPersonFeedback(
personID: widget.person.remoteID,
clusterID: clusterID,
);
@ -504,7 +505,7 @@ class _PersonClustersState extends State<PersonReviewClusterSuggestion> {
clusterID: lastFeedback.suggestion.clusterIDToMerge,
);
} else {
await MLDataDB.instance.removeNotPersonFeedback(
await mlDataDB.removeNotPersonFeedback(
personID: widget.person.remoteID,
clusterID: lastFeedback.suggestion.clusterIDToMerge,
);

View File

@ -46,6 +46,7 @@ class PersonFaceWidget extends StatefulWidget {
class _PersonFaceWidgetState extends State<PersonFaceWidget> {
Future<Uint8List?>? faceCropFuture;
late final mlDataDB = MLDataDB.instance;
@override
void initState() {
@ -101,13 +102,11 @@ class _PersonFaceWidgetState extends State<PersonFaceWidget> {
if (tryCache != null) return tryCache;
}
if (personAvatarFaceID == null && widget.cannotTrustFile) {
allFaces =
await MLDataDB.instance.getFaceIDsForPerson(widget.personId!);
allFaces = await mlDataDB.getFaceIDsForPerson(widget.personId!);
}
}
} else if (widget.clusterID != null && widget.cannotTrustFile) {
allFaces =
await MLDataDB.instance.getFaceIDsForCluster(widget.clusterID!);
allFaces = await mlDataDB.getFaceIDsForCluster(widget.clusterID!);
}
if (allFaces != null) {
final allFileIDs =
@ -137,7 +136,7 @@ class _PersonFaceWidgetState extends State<PersonFaceWidget> {
}
}
final Face? face = await MLDataDB.instance.getCoverFaceForPerson(
final Face? face = await mlDataDB.getCoverFaceForPerson(
recentFileID: fileForFaceCrop.uploadedFileID!,
avatarFaceId: personAvatarFaceID,
personID: widget.personId,