[mob][photos] Fix breakupCluster not returning cluster summaries

This commit is contained in:
laurenspriem 2024-05-29 11:13:21 +05:30
parent c8a3728f5d
commit cfb4ded991
2 changed files with 41 additions and 44 deletions

View File

@ -45,6 +45,8 @@ class ClusteringResult {
/// Whether [newFaceIdToCluster] is empty.
bool get isEmpty => newFaceIdToCluster.isEmpty;
/// Whether both [newClusterSummaries] and [newClusterIdToFaceIds] are
/// non-null.
bool get hasAllResults => newClusterSummaries != null && newClusterIdToFaceIds != null;
ClusteringResult({
required this.newFaceIdToCluster,
this.newClusterSummaries,
@ -127,8 +129,7 @@ class FaceClusteringService {
break;
}
} catch (e, stackTrace) {
sendPort
.send({'error': e.toString(), 'stackTrace': stackTrace.toString()});
sendPort.send({'error': e.toString(), 'stackTrace': stackTrace.toString()});
}
});
}
@ -256,6 +257,7 @@ class FaceClusteringService {
Future<ClusteringResult?> predictLinearComputer(
Map<String, Uint8List> input, {
Map<int, int>? fileIDToCreationTime,
required Map<int, (Uint8List, int)> oldClusterSummaries,
double distanceThreshold = kRecommendedDistanceThreshold,
}) async {
if (input.isEmpty) {
@ -291,6 +293,7 @@ class FaceClusteringService {
param: {
"input": clusteringInput,
"fileIDToCreationTime": fileIDToCreationTime,
"oldClusterSummaries": oldClusterSummaries,
"distanceThreshold": distanceThreshold,
"conservativeDistanceThreshold": distanceThreshold - 0.08,
"useDynamicThreshold": false,
@ -314,6 +317,7 @@ class FaceClusteringService {
Future<ClusteringResult> predictCompleteComputer(
Map<String, Uint8List> input, {
Map<int, int>? fileIDToCreationTime,
required Map<int, (Uint8List, int)> oldClusterSummaries,
double distanceThreshold = kRecommendedDistanceThreshold,
double mergeThreshold = 0.30,
}) async {
@ -336,6 +340,7 @@ class FaceClusteringService {
param: {
"input": input,
"fileIDToCreationTime": fileIDToCreationTime,
"oldClusterSummaries": oldClusterSummaries,
"distanceThreshold": distanceThreshold,
"mergeThreshold": mergeThreshold,
},
@ -355,6 +360,7 @@ class FaceClusteringService {
Future<ClusteringResult?> predictWithinClusterComputer(
Map<String, Uint8List> input, {
Map<int, int>? fileIDToCreationTime,
Map<int, (Uint8List, int)> oldClusterSummaries = const <int, (Uint8List, int)>{},
double distanceThreshold = kRecommendedDistanceThreshold,
}) async {
_logger.info(
@ -369,6 +375,7 @@ class FaceClusteringService {
final result = await predictCompleteComputer(
input,
fileIDToCreationTime: fileIDToCreationTime,
oldClusterSummaries: oldClusterSummaries,
distanceThreshold: distanceThreshold - 0.08,
mergeThreshold: mergeThreshold,
);
@ -381,6 +388,7 @@ class FaceClusteringService {
final clusterResult = await predictLinearComputer(
input,
fileIDToCreationTime: fileIDToCreationTime,
oldClusterSummaries: oldClusterSummaries,
distanceThreshold: distanceThreshold,
);
return clusterResult;
@ -396,12 +404,10 @@ class FaceClusteringService {
final input = args['input'] as Set<FaceInfoForClustering>;
final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;
final distanceThreshold = args['distanceThreshold'] as double;
final conservativeDistanceThreshold =
args['conservativeDistanceThreshold'] as double;
final conservativeDistanceThreshold = args['conservativeDistanceThreshold'] as double;
final useDynamicThreshold = args['useDynamicThreshold'] as bool;
final offset = args['offset'] as int?;
final oldClusterSummaries =
args['oldClusterSummaries'] as Map<int, (Uint8List, int)>?;
final oldClusterSummaries = args['oldClusterSummaries'] as Map<int, (Uint8List, int)>?;
log(
"[ClusterIsolate] ${DateTime.now()} Copied to isolate ${input.length} faces",
@ -425,8 +431,7 @@ class FaceClusteringService {
dtype: DType.float32,
),
clusterId: face.clusterId,
fileCreationTime:
fileIDToCreationTime?[getFileIdFromFaceId(face.faceID)],
fileCreationTime: fileIDToCreationTime?[getFileIdFromFaceId(face.faceID)],
),
);
}
@ -493,9 +498,8 @@ class FaceClusteringService {
double closestDistance = double.infinity;
late double thresholdValue;
if (useDynamicThreshold) {
thresholdValue = sortedFaceInfos[i].badFace!
? conservativeDistanceThreshold
: distanceThreshold;
thresholdValue =
sortedFaceInfos[i].badFace! ? conservativeDistanceThreshold : distanceThreshold;
if (sortedFaceInfos[i].badFace!) dynamicThresholdCount++;
} else {
thresholdValue = distanceThreshold;
@ -505,11 +509,10 @@ class FaceClusteringService {
}
// WARNING: The loop below is now O(n^2) so be very careful with anything you put in there!
for (int j = i - 1; j >= 0; j--) {
final double distance = 1 -
sortedFaceInfos[i].vEmbedding!.dot(sortedFaceInfos[j].vEmbedding!);
final double distance =
1 - sortedFaceInfos[i].vEmbedding!.dot(sortedFaceInfos[j].vEmbedding!);
if (distance < closestDistance) {
if (sortedFaceInfos[j].badFace! &&
distance > conservativeDistanceThreshold) {
if (sortedFaceInfos[j].badFace! && distance > conservativeDistanceThreshold) {
continue;
}
closestDistance = distance;
@ -535,8 +538,7 @@ class FaceClusteringService {
// Finally, assign the new clusterId to the faces
final Map<String, int> newFaceIdToCluster = {};
final newClusteredFaceInfos =
sortedFaceInfos.sublist(alreadyClusteredCount);
final newClusteredFaceInfos = sortedFaceInfos.sublist(alreadyClusteredCount);
for (final faceInfo in newClusteredFaceInfos) {
newFaceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!;
}
@ -597,9 +599,8 @@ class FaceClusteringService {
final Map<int, (Uint8List, int)> newClusterSummaries = {};
for (final clusterId in newClusterIdToFaceInfos.keys) {
final List<Vector> newEmbeddings = newClusterIdToFaceInfos[clusterId]!
.map((faceInfo) => faceInfo.vEmbedding!)
.toList();
final List<Vector> newEmbeddings =
newClusterIdToFaceInfos[clusterId]!.map((faceInfo) => faceInfo.vEmbedding!).toList();
final newCount = newEmbeddings.length;
if (oldSummary.containsKey(clusterId)) {
final oldMean = Vector.fromList(
@ -609,8 +610,7 @@ class FaceClusteringService {
final oldCount = oldSummary[clusterId]!.$2;
final oldEmbeddings = oldMean * oldCount;
newEmbeddings.add(oldEmbeddings);
final newMeanVector =
newEmbeddings.reduce((a, b) => a + b) / (oldCount + newCount);
final newMeanVector = newEmbeddings.reduce((a, b) => a + b) / (oldCount + newCount);
final newMeanVectorNormalized = newMeanVector / newMeanVector.norm();
newClusterSummaries[clusterId] = (
EVector(values: newMeanVectorNormalized.toList()).writeToBuffer(),
@ -619,10 +619,8 @@ class FaceClusteringService {
} else {
final newMeanVector = newEmbeddings.reduce((a, b) => a + b);
final newMeanVectorNormalized = newMeanVector / newMeanVector.norm();
newClusterSummaries[clusterId] = (
EVector(values: newMeanVectorNormalized.toList()).writeToBuffer(),
newCount
);
newClusterSummaries[clusterId] =
(EVector(values: newMeanVectorNormalized.toList()).writeToBuffer(), newCount);
}
}
log(
@ -696,6 +694,7 @@ class FaceClusteringService {
final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;
final distanceThreshold = args['distanceThreshold'] as double;
final mergeThreshold = args['mergeThreshold'] as double;
final oldClusterSummaries = args['oldClusterSummaries'] as Map<int, (Uint8List, int)>?;
log(
"[CompleteClustering] ${DateTime.now()} Copied to isolate ${input.length} faces for clustering",
@ -711,8 +710,7 @@ class FaceClusteringService {
EVector.fromBuffer(entry.value).values,
dtype: DType.float32,
),
fileCreationTime:
fileIDToCreationTime?[getFileIdFromFaceId(entry.key)],
fileCreationTime: fileIDToCreationTime?[getFileIdFromFaceId(entry.key)],
),
);
}
@ -746,8 +744,7 @@ class FaceClusteringService {
double closestDistance = double.infinity;
for (int j = 0; j < totalFaces; j++) {
if (i == j) continue;
final double distance =
1 - faceInfos[i].vEmbedding!.dot(faceInfos[j].vEmbedding!);
final double distance = 1 - faceInfos[i].vEmbedding!.dot(faceInfos[j].vEmbedding!);
if (distance < closestDistance) {
closestDistance = distance;
closestIdx = j;
@ -777,21 +774,17 @@ class FaceClusteringService {
}
final Map<int, (Vector, int)> clusterIdToMeanEmbeddingAndWeight = {};
for (final clusterId in clusterIdToFaceInfos.keys) {
final List<Vector> embeddings = clusterIdToFaceInfos[clusterId]!
.map((faceInfo) => faceInfo.vEmbedding!)
.toList();
final List<Vector> embeddings =
clusterIdToFaceInfos[clusterId]!.map((faceInfo) => faceInfo.vEmbedding!).toList();
final count = clusterIdToFaceInfos[clusterId]!.length;
final Vector meanEmbedding = embeddings.reduce((a, b) => a + b) / count;
final Vector meanEmbeddingNormalized =
meanEmbedding / meanEmbedding.norm();
clusterIdToMeanEmbeddingAndWeight[clusterId] =
(meanEmbeddingNormalized, count);
final Vector meanEmbeddingNormalized = meanEmbedding / meanEmbedding.norm();
clusterIdToMeanEmbeddingAndWeight[clusterId] = (meanEmbeddingNormalized, count);
}
// Now merge the clusters that are close to each other, based on mean embedding
final List<(int, int)> mergedClustersList = [];
final List<int> clusterIds =
clusterIdToMeanEmbeddingAndWeight.keys.toList();
final List<int> clusterIds = clusterIdToMeanEmbeddingAndWeight.keys.toList();
log(' [CompleteClustering] ${DateTime.now()} ${clusterIds.length} clusters found, now checking for merges');
while (true) {
if (clusterIds.length < 2) break;
@ -858,10 +851,14 @@ class FaceClusteringService {
}
}
final newClusterSummaries = FaceClusteringService.updateClusterSummaries(
oldSummary: <int, (Uint8List, int)>{},
newFaceInfos: faceInfos,
);
// Now calculate the mean of the embeddings for each cluster and update the cluster summaries
Map<int, (Uint8List, int)>? newClusterSummaries;
if (oldClusterSummaries != null) {
newClusterSummaries = FaceClusteringService.updateClusterSummaries(
oldSummary: oldClusterSummaries,
newFaceInfos: faceInfos,
);
}
stopwatchClustering.stop();
log(

View File

@ -434,7 +434,7 @@ class ClusterFeedbackService {
);
if (clusterResult == null ||
clusterResult.newClusterIdToFaceIds == null ||
!clusterResult.hasAllResults ||
clusterResult.isEmpty) {
_logger.warning('No clusters found or something went wrong');
return ClusteringResult(newFaceIdToCluster: {});