Ml sync fix (#4027)

## Description

Remotely safe faceIDs of a certain person that are rejected by the user.
This commit is contained in:
Laurens Priem 2024-11-19 13:43:45 +05:30 committed by GitHub
commit 8b07db8a73
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 183 additions and 26 deletions

View File

@ -17,14 +17,17 @@ import "package:photos/services/machine_learning/ml_result.dart";
import "package:photos/utils/ml_util.dart";
import 'package:sqlite_async/sqlite_async.dart';
/// Stores all data for the FacesML-related features. The database can be accessed by `MLDataDB.instance.database`.
/// Stores all data for the ML related features. The database can be accessed by `MLDataDB.instance.database`.
///
/// This includes:
/// [facesTable] - Stores all the detected faces and its embeddings in the images.
/// [createFaceClustersTable] - Stores all the mappings from the faces (faceID) to the clusters (clusterID).
/// [faceClustersTable] - Stores all the mappings from the faces (faceID) to the clusters (clusterID).
/// [clusterPersonTable] - Stores all the clusters that are mapped to a certain person.
/// [clusterSummaryTable] - Stores a summary of each cluster, containg the mean embedding and the number of faces in the cluster.
/// [notPersonFeedback] - Stores the clusters that are confirmed not to belong to a certain person by the user
///
/// [clipTable] - Stores the embeddings of the CLIP model
/// [fileDataTable] - Stores data about the files that are already processed by the ML models
class MLDataDB {
static final Logger _logger = Logger("MLDataDB");
@ -477,6 +480,25 @@ class MLDataDB {
return result;
}
Future<Map<String, Set<String>>> getClusterIdToFaceIdsForPerson(
String personID,
) async {
final db = await instance.asyncDB;
final List<Map<String, dynamic>> maps = await db.getAll(
'SELECT $faceClustersTable.$clusterIDColumn, $faceIDColumn FROM $clusterPersonTable '
'INNER JOIN $faceClustersTable ON $clusterPersonTable.$clusterIDColumn = $faceClustersTable.$clusterIDColumn '
'WHERE $personIdColumn = ?',
[personID],
);
final Map<String, Set<String>> result = {};
for (final map in maps) {
final clusterID = map[clusterIDColumn] as String;
final faceID = map[faceIDColumn] as String;
result.putIfAbsent(clusterID, () => {}).add(faceID);
}
return result;
}
Future<Set<String>> getFaceIDsForPerson(String personID) async {
final db = await instance.asyncDB;
final faceIdsResult = await db.getAll(
@ -553,6 +575,19 @@ class MLDataDB {
await db.executeBatch(sql, parameterSets);
}
Future<void> removeFaceIdToClusterId(
Map<String, String> faceIDToClusterID,
) async {
final db = await instance.asyncDB;
const String sql = '''
DELETE FROM $faceClustersTable
WHERE $faceIDColumn = ? AND $clusterIDColumn = ?
''';
final parameterSets =
faceIDToClusterID.entries.map((e) => [e.key, e.value]).toList();
await db.executeBatch(sql, parameterSets);
}
Future<void> removePerson(String personID) async {
final db = await instance.asyncDB;

View File

@ -51,7 +51,7 @@ class PersonData {
final bool isHidden;
String? avatarFaceID;
List<ClusterInfo>? assigned = List<ClusterInfo>.empty();
List<ClusterInfo>? rejected = List<ClusterInfo>.empty();
List<String>? rejectedFaceIDs = List<String>.empty();
final String? birthDate;
bool hasAvatar() => avatarFaceID != null;
@ -62,7 +62,7 @@ class PersonData {
PersonData({
required this.name,
this.assigned,
this.rejected,
this.rejectedFaceIDs,
this.avatarFaceID,
this.isHidden = false,
this.birthDate,
@ -79,7 +79,7 @@ class PersonData {
return PersonData(
name: name ?? this.name,
assigned: assigned ?? this.assigned,
avatarFaceID: avatarFaceId ?? this.avatarFaceID,
avatarFaceID: avatarFaceId ?? avatarFaceID,
isHidden: isHidden ?? this.isHidden,
birthDate: birthDate ?? this.birthDate,
);
@ -95,7 +95,7 @@ class PersonData {
assignedCount += a.faces.length;
}
sb.writeln('Assigned: ${assigned?.length} withFaces $assignedCount');
sb.writeln('Rejected: ${rejected?.length}');
sb.writeln('Rejected faceIDs: ${rejectedFaceIDs?.length}');
if (assigned != null) {
for (var cluster in assigned!) {
sb.writeln('Cluster: ${cluster.id} - ${cluster.faces.length}');
@ -108,7 +108,7 @@ class PersonData {
Map<String, dynamic> toJson() => {
'name': name,
'assigned': assigned?.map((e) => e.toJson()).toList(),
'rejected': rejected?.map((e) => e.toJson()).toList(),
'rejectedFaceIDs': rejectedFaceIDs,
'avatarFaceID': avatarFaceID,
'isHidden': isHidden,
'birthDate': birthDate,
@ -122,15 +122,16 @@ class PersonData {
json['assigned'].map((x) => ClusterInfo.fromJson(x)),
);
final rejected = (json['rejected'] == null || json['rejected'].length == 0)
? <ClusterInfo>[]
: List<ClusterInfo>.from(
json['rejected'].map((x) => ClusterInfo.fromJson(x)),
);
final List<String> rejectedFaceIDs =
(json['rejectedFaceIDs'] == null || json['rejectedFaceIDs'].length == 0)
? <String>[]
: List<String>.from(
json['rejectedFaceIDs'],
);
return PersonData(
name: json['name'] as String,
assigned: assigned,
rejected: rejected,
rejectedFaceIDs: rejectedFaceIDs,
avatarFaceID: json['avatarFaceID'] as String?,
isHidden: json['isHidden'] as bool? ?? false,
birthDate: json['birthDate'] as String?,

View File

@ -22,6 +22,7 @@ class FaceInfo {
final bool? badFace;
final Vector? vEmbedding;
String? clusterId;
final List<String>? rejectedClusterIds;
String? closestFaceId;
int? closestDist;
int? fileCreationTime;
@ -32,6 +33,7 @@ class FaceInfo {
this.badFace,
this.vEmbedding,
this.clusterId,
this.rejectedClusterIds,
this.fileCreationTime,
});
}
@ -161,7 +163,7 @@ class FaceClusteringService extends SuperIsolate {
_logger.info(
'Running complete clustering on ${input.length} faces with distance threshold $mergeThreshold',
);
final ClusteringResult clusterResult = await predictCompleteComputer(
final ClusteringResult clusterResult = await _predictCompleteComputer(
input,
fileIDToCreationTime: fileIDToCreationTime,
oldClusterSummaries: oldClusterSummaries,
@ -173,7 +175,7 @@ class FaceClusteringService extends SuperIsolate {
_logger.info(
'Running linear clustering on ${input.length} faces with distance threshold $distanceThreshold',
);
final ClusteringResult clusterResult = await predictLinearComputer(
final ClusteringResult clusterResult = await _predictLinearComputer(
input,
fileIDToCreationTime: fileIDToCreationTime,
oldClusterSummaries: oldClusterSummaries,
@ -188,7 +190,7 @@ class FaceClusteringService extends SuperIsolate {
}
/// Runs the clustering algorithm [runLinearClustering] on the given [input], in computer, without any dynamic thresholding
Future<ClusteringResult> predictLinearComputer(
Future<ClusteringResult> _predictLinearComputer(
Map<String, Uint8List> input, {
Map<int, int>? fileIDToCreationTime,
required Map<String, (Uint8List, int)> oldClusterSummaries,
@ -248,7 +250,7 @@ class FaceClusteringService extends SuperIsolate {
/// Runs the clustering algorithm [_runCompleteClustering] on the given [input], in computer.
///
/// WARNING: Only use on small datasets, as it is not optimized for large datasets.
Future<ClusteringResult> predictCompleteComputer(
Future<ClusteringResult> _predictCompleteComputer(
Map<String, Uint8List> input, {
Map<int, int>? fileIDToCreationTime,
required Map<String, (Uint8List, int)> oldClusterSummaries,
@ -328,6 +330,7 @@ ClusteringResult runLinearClustering(Map args) {
dtype: DType.float32,
),
clusterId: face.clusterId,
rejectedClusterIds: face.rejectedClusterIds,
fileCreationTime:
fileIDToCreationTime?[getFileIdFromFaceId(face.faceID)],
),
@ -372,7 +375,6 @@ ClusteringResult runLinearClustering(Map args) {
_logger.info(
"[ClusterIsolate] ${DateTime.now()} Processing $totalFaces faces ($newToClusterCount new, $alreadyClusteredCount already done) in total in this round ${offset != null ? "on top of ${offset + facesWithClusterID.length} earlier processed faces" : ""}",
);
// set current epoch time as clusterID
String clusterID = newClusterID();
if (facesWithClusterID.isEmpty) {
// assign a clusterID to the first face
@ -398,6 +400,7 @@ ClusteringResult runLinearClustering(Map args) {
} else {
thresholdValue = distanceThreshold;
}
final bool faceHasBeenRejectedBefore = sortedFaceInfos[i].rejectedClusterIds != null;
if (i % 250 == 0) {
_logger.info("Processed ${offset != null ? i + offset : i} faces");
}
@ -410,6 +413,13 @@ ClusteringResult runLinearClustering(Map args) {
distance > conservativeDistanceThreshold) {
continue;
}
if (faceHasBeenRejectedBefore &&
sortedFaceInfos[j].clusterId != null &&
sortedFaceInfos[i].rejectedClusterIds!.contains(
sortedFaceInfos[j].clusterId!,
)) {
continue;
}
closestDistance = distance;
closestIdx = j;
}

View File

@ -3,6 +3,7 @@ import "dart:typed_data" show Uint8List;
class FaceDbInfoForClustering {
final String faceID;
String? clusterId;
List<String>? rejectedClusterIds;
final Uint8List embeddingBytes;
final double faceScore;
final double blurValue;

View File

@ -334,6 +334,31 @@ class ClusterFeedbackService {
return true;
}
Future<void> addClusterToExistingPerson({
required PersonEntity person,
required String clusterID,
}) async {
if (person.data.rejectedFaceIDs != null &&
person.data.rejectedFaceIDs!.isNotEmpty) {
final clusterFaceIDs =
await MLDataDB.instance.getFaceIDsForCluster(clusterID);
final rejectedLengthBefore = person.data.rejectedFaceIDs!.length;
person.data.rejectedFaceIDs!
.removeWhere((faceID) => clusterFaceIDs.contains(faceID));
final rejectedLengthAfter = person.data.rejectedFaceIDs!.length;
if (rejectedLengthBefore != rejectedLengthAfter) {
_logger.info(
'Removed ${rejectedLengthBefore - rejectedLengthAfter} rejected faces from person ${person.data.name} due to adding cluster $clusterID',
);
await PersonService.instance.updatePerson(person);
}
}
await MLDataDB.instance.assignClusterToPerson(
personID: person.remoteID,
clusterID: clusterID,
);
}
Future<void> ignoreCluster(String clusterID) async {
await PersonService.instance.addPerson('', clusterID, isHidden: true);
Bus.instance.fire(PeopleChangedEvent());

View File

@ -183,6 +183,11 @@ class PersonService {
}) async {
final person = (await getPerson(personID))!;
final personData = person.data;
final clusterInfo = personData.assigned!.firstWhere(
(element) => element.id == clusterID,
);
personData.rejectedFaceIDs ??= [];
personData.rejectedFaceIDs!.addAll(clusterInfo.faces);
personData.assigned!.removeWhere((element) => element.id != clusterID);
await entityService.addOrUpdate(
EntityType.cgroup,
@ -201,6 +206,8 @@ class PersonService {
required Set<String> faceIDs,
}) async {
final personData = person.data;
// Remove faces from clusters
final List<String> emptiedClusters = [];
for (final cluster in personData.assigned!) {
cluster.faces.removeWhere((faceID) => faceIDs.contains(faceID));
@ -219,6 +226,10 @@ class PersonService {
);
}
// Add removed faces to rejected faces
personData.rejectedFaceIDs ??= [];
personData.rejectedFaceIDs!.addAll(faceIDs);
await entityService.addOrUpdate(
EntityType.cgroup,
personData.toJson(),
@ -271,9 +282,16 @@ class PersonService {
entities.sort((a, b) => a.updatedAt.compareTo(b.updatedAt));
final Map<String, String> faceIdToClusterID = {};
final Map<String, String> clusterToPersonID = {};
bool shouldCheckRejectedFaces = false;
for (var e in entities) {
final personData = PersonData.fromJson(json.decode(e.data));
if (personData.rejectedFaceIDs != null &&
personData.rejectedFaceIDs!.isNotEmpty) {
shouldCheckRejectedFaces = true;
}
int faceCount = 0;
// Locally store the assignment of faces to clusters and people
for (var cluster in personData.assigned!) {
faceCount += cluster.faces.length;
for (var faceId in cluster.faces) {
@ -303,6 +321,57 @@ class PersonService {
logger.info("Storing feedback for ${faceIdToClusterID.length} faces");
await faceMLDataDB.updateFaceIdToClusterId(faceIdToClusterID);
await faceMLDataDB.bulkAssignClusterToPersonID(clusterToPersonID);
if (shouldCheckRejectedFaces) {
final dbPeopleClusterInfo =
await faceMLDataDB.getPersonToClusterIdToFaceIds();
for (var e in entities) {
final personData = PersonData.fromJson(json.decode(e.data));
if (personData.rejectedFaceIDs != null &&
personData.rejectedFaceIDs!.isNotEmpty) {
final personFaceIDs =
dbPeopleClusterInfo[e.id]!.values.expand((e) => e).toSet();
final rejectedFaceIDsSet = personData.rejectedFaceIDs!.toSet();
final assignedAndRejectedFaceIDs =
rejectedFaceIDsSet.intersection(personFaceIDs);
if (assignedAndRejectedFaceIDs.isNotEmpty) {
// Check that we don't have any empty clusters now
final dbPersonClusterInfo = dbPeopleClusterInfo[e.id]!;
final faceToClusterToRemove = <String, String>{};
for (final clusterIdToFaceIDs in dbPersonClusterInfo.entries) {
final clusterID = clusterIdToFaceIDs.key;
final faceIDs = clusterIdToFaceIDs.value;
final foundRejectedFacesToCluster = <String, String>{};
for (final faceID in faceIDs) {
if (assignedAndRejectedFaceIDs.contains(faceID)) {
faceIDs.remove(faceID);
foundRejectedFacesToCluster[faceID] = clusterID;
}
}
if (faceIDs.isEmpty) {
logger.info(
"Cluster $clusterID for person ${e.id} ${personData.name} is empty due to rejected faces from remote, removing the cluster from person",
);
await faceMLDataDB.removeClusterToPerson(
personID: e.id,
clusterID: clusterID,
);
await faceMLDataDB.captureNotPersonFeedback(
personID: e.id,
clusterID: clusterID,
);
} else {
faceToClusterToRemove.addAll(foundRejectedFacesToCluster);
}
}
// Remove the clusterID for the remaining conflicting faces
await faceMLDataDB.removeFaceIdToClusterId(faceToClusterToRemove);
}
}
}
}
return changed;
}
@ -321,7 +390,7 @@ class PersonService {
final updatedPerson = person.copyWith(
data: person.data.copyWith(avatarFaceId: face.faceID),
);
await _updatePerson(updatedPerson);
await updatePerson(updatedPerson);
}
Future<void> updateAttributes(
@ -342,10 +411,10 @@ class PersonService {
birthDate: birthDate,
),
);
await _updatePerson(updatedPerson);
await updatePerson(updatedPerson);
}
Future<void> _updatePerson(PersonEntity updatePerson) async {
Future<void> updatePerson(PersonEntity updatePerson) async {
await entityService.addOrUpdate(
EntityType.cgroup,
updatePerson.data.toJson(),

View File

@ -250,6 +250,19 @@ class MLService {
_logger.info('Pulling remote feedback before actually clustering');
await PersonService.instance.fetchRemoteClusterFeedback();
final persons = await PersonService.instance.getPersons();
final faceIdNotToCluster = <String, List<String>>{};
for (final person in persons) {
if (person.data.rejectedFaceIDs != null &&
person.data.rejectedFaceIDs!.isNotEmpty) {
final personClusters = person.data.assigned?.map((e) => e.id).toList();
if (personClusters != null) {
for (final faceID in person.data.rejectedFaceIDs!) {
faceIdNotToCluster[faceID] = personClusters;
}
}
}
}
try {
_showClusteringIsHappening = true;
@ -271,6 +284,9 @@ class MLService {
if (!fileIDToCreationTime.containsKey(faceInfo.fileID)) {
missingFileIDs.add(faceInfo.fileID);
} else {
if (faceIdNotToCluster.containsKey(faceInfo.faceID)) {
faceInfo.rejectedClusterIds = faceIdNotToCluster[faceInfo.faceID];
}
allFaceInfoForClustering.add(faceInfo);
}
}

View File

@ -7,7 +7,6 @@ import 'package:flutter/material.dart';
import "package:logging/logging.dart";
import 'package:modal_bottom_sheet/modal_bottom_sheet.dart';
import "package:photos/core/event_bus.dart";
import "package:photos/db/ml/db.dart";
import "package:photos/events/people_changed_event.dart";
import "package:photos/generated/l10n.dart";
import "package:photos/models/file/file.dart";
@ -255,8 +254,9 @@ class _PersonActionSheetState extends State<PersonActionSheet> {
return;
}
userAlreadyAssigned = true;
await MLDataDB.instance.assignClusterToPerson(
personID: person.$1.remoteID,
await ClusterFeedbackService.instance
.addClusterToExistingPerson(
person: person.$1,
clusterID: widget.cluserID,
);
Bus.instance.fire(PeopleChangedEvent());

View File

@ -199,8 +199,8 @@ class _PersonClustersState extends State<PersonReviewClusterSuggestion> {
);
if (yesOrNo) {
canGiveFeedback = false;
await MLDataDB.instance.assignClusterToPerson(
personID: widget.person.remoteID,
await ClusterFeedbackService.instance.addClusterToExistingPerson(
person: widget.person,
clusterID: clusterID,
);
Bus.instance.fire(PeopleChangedEvent());