mirror of https://github.com/ente-io/ente.git
[mob][photos] Fix breakupCluster not returning cluster summaries
parent c8a3728f5d
commit cfb4ded991
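
In short, this commit makes the clustering pipeline return cluster summaries again: ClusteringResult gains a hasAllResults getter, the predict* entry points now accept the existing oldClusterSummaries, and the complete-clustering isolate builds newClusterSummaries from them instead of always passing an empty map, which is what breakupCluster relies on per the commit title. A condensed sketch of the resulting ClusteringResult shape, pieced together from the hunks below (the type of newClusterIdToFaceIds is not shown in the diff and is assumed here):

import 'dart:typed_data';

class ClusteringResult {
  final Map<String, int> newFaceIdToCluster;
  // Per-cluster summary: serialized mean embedding (EVector bytes) and face count.
  final Map<int, (Uint8List, int)>? newClusterSummaries;
  // Assumed type: face IDs grouped per new cluster (not spelled out in the diff).
  final Map<int, List<String>>? newClusterIdToFaceIds;

  ClusteringResult({
    required this.newFaceIdToCluster,
    this.newClusterSummaries,
    this.newClusterIdToFaceIds,
  });

  bool get isEmpty => newFaceIdToCluster.isEmpty;

  // Added in this commit: lets callers check that both optional maps were filled.
  bool get hasAllResults => newClusterSummaries != null && newClusterIdToFaceIds != null;
}
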
@@ -45,6 +45,8 @@ class ClusteringResult {
 
   bool get isEmpty => newFaceIdToCluster.isEmpty;
 
+  bool get hasAllResults => newClusterSummaries != null && newClusterIdToFaceIds != null;
+
   ClusteringResult({
     required this.newFaceIdToCluster,
     this.newClusterSummaries,
@@ -127,8 +129,7 @@ class FaceClusteringService {
             break;
         }
       } catch (e, stackTrace) {
-        sendPort
-            .send({'error': e.toString(), 'stackTrace': stackTrace.toString()});
+        sendPort.send({'error': e.toString(), 'stackTrace': stackTrace.toString()});
       }
     });
   }
@@ -256,6 +257,7 @@ class FaceClusteringService {
   Future<ClusteringResult?> predictLinearComputer(
     Map<String, Uint8List> input, {
     Map<int, int>? fileIDToCreationTime,
+    required Map<int, (Uint8List, int)> oldClusterSummaries,
     double distanceThreshold = kRecommendedDistanceThreshold,
   }) async {
     if (input.isEmpty) {
@@ -291,6 +293,7 @@ class FaceClusteringService {
       param: {
         "input": clusteringInput,
         "fileIDToCreationTime": fileIDToCreationTime,
+        "oldClusterSummaries": oldClusterSummaries,
         "distanceThreshold": distanceThreshold,
         "conservativeDistanceThreshold": distanceThreshold - 0.08,
         "useDynamicThreshold": false,
@@ -314,6 +317,7 @@ class FaceClusteringService {
   Future<ClusteringResult> predictCompleteComputer(
     Map<String, Uint8List> input, {
     Map<int, int>? fileIDToCreationTime,
+    required Map<int, (Uint8List, int)> oldClusterSummaries,
     double distanceThreshold = kRecommendedDistanceThreshold,
     double mergeThreshold = 0.30,
   }) async {
@@ -336,6 +340,7 @@ class FaceClusteringService {
       param: {
         "input": input,
         "fileIDToCreationTime": fileIDToCreationTime,
+        "oldClusterSummaries": oldClusterSummaries,
         "distanceThreshold": distanceThreshold,
         "mergeThreshold": mergeThreshold,
       },
@@ -355,6 +360,7 @@ class FaceClusteringService {
   Future<ClusteringResult?> predictWithinClusterComputer(
     Map<String, Uint8List> input, {
     Map<int, int>? fileIDToCreationTime,
+    Map<int, (Uint8List, int)> oldClusterSummaries = const <int, (Uint8List, int)>{},
     double distanceThreshold = kRecommendedDistanceThreshold,
   }) async {
     _logger.info(
@@ -369,6 +375,7 @@ class FaceClusteringService {
       final result = await predictCompleteComputer(
         input,
         fileIDToCreationTime: fileIDToCreationTime,
+        oldClusterSummaries: oldClusterSummaries,
         distanceThreshold: distanceThreshold - 0.08,
         mergeThreshold: mergeThreshold,
       );
@@ -381,6 +388,7 @@ class FaceClusteringService {
       final clusterResult = await predictLinearComputer(
         input,
         fileIDToCreationTime: fileIDToCreationTime,
+        oldClusterSummaries: oldClusterSummaries,
         distanceThreshold: distanceThreshold,
       );
       return clusterResult;
@@ -396,12 +404,10 @@ class FaceClusteringService {
     final input = args['input'] as Set<FaceInfoForClustering>;
     final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;
     final distanceThreshold = args['distanceThreshold'] as double;
-    final conservativeDistanceThreshold =
-        args['conservativeDistanceThreshold'] as double;
+    final conservativeDistanceThreshold = args['conservativeDistanceThreshold'] as double;
     final useDynamicThreshold = args['useDynamicThreshold'] as bool;
     final offset = args['offset'] as int?;
-    final oldClusterSummaries =
-        args['oldClusterSummaries'] as Map<int, (Uint8List, int)>?;
+    final oldClusterSummaries = args['oldClusterSummaries'] as Map<int, (Uint8List, int)>?;
 
     log(
       "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${input.length} faces",
@@ -425,8 +431,7 @@ class FaceClusteringService {
           dtype: DType.float32,
         ),
         clusterId: face.clusterId,
-        fileCreationTime:
-            fileIDToCreationTime?[getFileIdFromFaceId(face.faceID)],
+        fileCreationTime: fileIDToCreationTime?[getFileIdFromFaceId(face.faceID)],
       ),
     );
   }
@@ -493,9 +498,8 @@ class FaceClusteringService {
       double closestDistance = double.infinity;
       late double thresholdValue;
       if (useDynamicThreshold) {
-        thresholdValue = sortedFaceInfos[i].badFace!
-            ? conservativeDistanceThreshold
-            : distanceThreshold;
+        thresholdValue =
+            sortedFaceInfos[i].badFace! ? conservativeDistanceThreshold : distanceThreshold;
         if (sortedFaceInfos[i].badFace!) dynamicThresholdCount++;
       } else {
         thresholdValue = distanceThreshold;
@@ -505,11 +509,10 @@ class FaceClusteringService {
       }
       // WARNING: The loop below is now O(n^2) so be very careful with anything you put in there!
       for (int j = i - 1; j >= 0; j--) {
-        final double distance = 1 -
-            sortedFaceInfos[i].vEmbedding!.dot(sortedFaceInfos[j].vEmbedding!);
+        final double distance =
+            1 - sortedFaceInfos[i].vEmbedding!.dot(sortedFaceInfos[j].vEmbedding!);
         if (distance < closestDistance) {
-          if (sortedFaceInfos[j].badFace! &&
-              distance > conservativeDistanceThreshold) {
+          if (sortedFaceInfos[j].badFace! && distance > conservativeDistanceThreshold) {
            continue;
          }
          closestDistance = distance;
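
As context for the two threshold hunks above: the linear pass walks the sorted faces and, for each one, scans all earlier faces for the nearest neighbour by cosine distance, joining that neighbour's cluster when the distance is under the (possibly more conservative) threshold and starting a new cluster otherwise. A stripped-down sketch of that idea, using plain lists instead of the service's FaceInfo and ml_linalg Vector types and omitting the bad-face/dynamic-threshold handling; all names here are illustrative:

// Greedy incremental clustering over normalized embeddings (illustrative only).
List<int> greedyLinearCluster(List<List<double>> embeddings, double threshold) {
  final clusterIds = List<int>.filled(embeddings.length, -1);
  var nextClusterId = 0;
  for (var i = 0; i < embeddings.length; i++) {
    var closestDistance = double.infinity;
    var closestIdx = -1;
    // Same O(n^2) shape as the loop in the hunk above: compare against all earlier faces.
    for (var j = i - 1; j >= 0; j--) {
      var dot = 0.0;
      for (var k = 0; k < embeddings[i].length; k++) {
        dot += embeddings[i][k] * embeddings[j][k];
      }
      final distance = 1 - dot; // cosine distance for unit-length embeddings
      if (distance < closestDistance) {
        closestDistance = distance;
        closestIdx = j;
      }
    }
    if (closestIdx >= 0 && closestDistance < threshold) {
      clusterIds[i] = clusterIds[closestIdx]; // join the nearest face's cluster
    } else {
      clusterIds[i] = nextClusterId++; // start a new cluster
    }
  }
  return clusterIds;
}
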
@@ -535,8 +538,7 @@ class FaceClusteringService {
 
     // Finally, assign the new clusterId to the faces
     final Map<String, int> newFaceIdToCluster = {};
-    final newClusteredFaceInfos =
-        sortedFaceInfos.sublist(alreadyClusteredCount);
+    final newClusteredFaceInfos = sortedFaceInfos.sublist(alreadyClusteredCount);
     for (final faceInfo in newClusteredFaceInfos) {
       newFaceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!;
     }
@@ -597,9 +599,8 @@ class FaceClusteringService {
 
     final Map<int, (Uint8List, int)> newClusterSummaries = {};
     for (final clusterId in newClusterIdToFaceInfos.keys) {
-      final List<Vector> newEmbeddings = newClusterIdToFaceInfos[clusterId]!
-          .map((faceInfo) => faceInfo.vEmbedding!)
-          .toList();
+      final List<Vector> newEmbeddings =
+          newClusterIdToFaceInfos[clusterId]!.map((faceInfo) => faceInfo.vEmbedding!).toList();
       final newCount = newEmbeddings.length;
       if (oldSummary.containsKey(clusterId)) {
         final oldMean = Vector.fromList(
@@ -609,8 +610,7 @@ class FaceClusteringService {
         final oldCount = oldSummary[clusterId]!.$2;
         final oldEmbeddings = oldMean * oldCount;
         newEmbeddings.add(oldEmbeddings);
-        final newMeanVector =
-            newEmbeddings.reduce((a, b) => a + b) / (oldCount + newCount);
+        final newMeanVector = newEmbeddings.reduce((a, b) => a + b) / (oldCount + newCount);
         final newMeanVectorNormalized = newMeanVector / newMeanVector.norm();
         newClusterSummaries[clusterId] = (
           EVector(values: newMeanVectorNormalized.toList()).writeToBuffer(),
@@ -619,10 +619,8 @@ class FaceClusteringService {
       } else {
         final newMeanVector = newEmbeddings.reduce((a, b) => a + b);
         final newMeanVectorNormalized = newMeanVector / newMeanVector.norm();
-        newClusterSummaries[clusterId] = (
-          EVector(values: newMeanVectorNormalized.toList()).writeToBuffer(),
-          newCount
-        );
+        newClusterSummaries[clusterId] =
+            (EVector(values: newMeanVectorNormalized.toList()).writeToBuffer(), newCount);
       }
     }
     log(
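
The two hunks above maintain each cluster's summary as a running mean: the stored (normalized) mean is scaled back up by the old face count, summed with the new embeddings, divided by the combined count, and re-normalized before being serialized into an EVector buffer. A minimal sketch of just that arithmetic using the same ml_linalg Vector type; the helper name and signature are illustrative, not part of the codebase:

import 'package:ml_linalg/linalg.dart';

// Illustrative helper mirroring the update in updateClusterSummaries above:
//   newMean = (oldMean * oldCount + sum(newEmbeddings)) / (oldCount + newCount)
// re-normalized to unit length before being stored.
Vector updatedClusterMean(
  Vector oldMeanNormalized,
  int oldCount,
  List<Vector> newEmbeddings,
) {
  final Vector oldSum = oldMeanNormalized * oldCount; // undo the previous averaging
  final Vector totalSum = newEmbeddings.fold<Vector>(oldSum, (sum, v) => sum + v);
  final Vector newMean = totalSum / (oldCount + newEmbeddings.length);
  return newMean / newMean.norm(); // keep the stored summary unit-length
}
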
@@ -696,6 +694,7 @@ class FaceClusteringService {
     final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;
     final distanceThreshold = args['distanceThreshold'] as double;
     final mergeThreshold = args['mergeThreshold'] as double;
+    final oldClusterSummaries = args['oldClusterSummaries'] as Map<int, (Uint8List, int)>?;
 
     log(
       "[CompleteClustering] ${DateTime.now()} Copied to isolate ${input.length} faces for clustering",
@@ -711,8 +710,7 @@ class FaceClusteringService {
           EVector.fromBuffer(entry.value).values,
           dtype: DType.float32,
         ),
-        fileCreationTime:
-            fileIDToCreationTime?[getFileIdFromFaceId(entry.key)],
+        fileCreationTime: fileIDToCreationTime?[getFileIdFromFaceId(entry.key)],
       ),
     );
   }
@@ -746,8 +744,7 @@ class FaceClusteringService {
       double closestDistance = double.infinity;
       for (int j = 0; j < totalFaces; j++) {
         if (i == j) continue;
-        final double distance =
-            1 - faceInfos[i].vEmbedding!.dot(faceInfos[j].vEmbedding!);
+        final double distance = 1 - faceInfos[i].vEmbedding!.dot(faceInfos[j].vEmbedding!);
         if (distance < closestDistance) {
           closestDistance = distance;
           closestIdx = j;
@@ -777,21 +774,17 @@ class FaceClusteringService {
     }
     final Map<int, (Vector, int)> clusterIdToMeanEmbeddingAndWeight = {};
     for (final clusterId in clusterIdToFaceInfos.keys) {
-      final List<Vector> embeddings = clusterIdToFaceInfos[clusterId]!
-          .map((faceInfo) => faceInfo.vEmbedding!)
-          .toList();
+      final List<Vector> embeddings =
+          clusterIdToFaceInfos[clusterId]!.map((faceInfo) => faceInfo.vEmbedding!).toList();
       final count = clusterIdToFaceInfos[clusterId]!.length;
       final Vector meanEmbedding = embeddings.reduce((a, b) => a + b) / count;
-      final Vector meanEmbeddingNormalized =
-          meanEmbedding / meanEmbedding.norm();
-      clusterIdToMeanEmbeddingAndWeight[clusterId] =
-          (meanEmbeddingNormalized, count);
+      final Vector meanEmbeddingNormalized = meanEmbedding / meanEmbedding.norm();
+      clusterIdToMeanEmbeddingAndWeight[clusterId] = (meanEmbeddingNormalized, count);
     }
 
     // Now merge the clusters that are close to each other, based on mean embedding
     final List<(int, int)> mergedClustersList = [];
-    final List<int> clusterIds =
-        clusterIdToMeanEmbeddingAndWeight.keys.toList();
+    final List<int> clusterIds = clusterIdToMeanEmbeddingAndWeight.keys.toList();
     log(' [CompleteClustering] ${DateTime.now()} ${clusterIds.length} clusters found, now checking for merges');
     while (true) {
       if (clusterIds.length < 2) break;
@@ -858,10 +851,14 @@ class FaceClusteringService {
       }
     }
 
-    final newClusterSummaries = FaceClusteringService.updateClusterSummaries(
-      oldSummary: <int, (Uint8List, int)>{},
-      newFaceInfos: faceInfos,
-    );
+    // Now calculate the mean of the embeddings for each cluster and update the cluster summaries
+    Map<int, (Uint8List, int)>? newClusterSummaries;
+    if (oldClusterSummaries != null) {
+      newClusterSummaries = FaceClusteringService.updateClusterSummaries(
+        oldSummary: oldClusterSummaries,
+        newFaceInfos: faceInfos,
+      );
+    }
 
     stopwatchClustering.stop();
     log(
@@ -434,7 +434,7 @@ class ClusterFeedbackService {
     );
 
     if (clusterResult == null ||
-        clusterResult.newClusterIdToFaceIds == null ||
+        !clusterResult.hasAllResults ||
         clusterResult.isEmpty) {
       _logger.warning('No clusters found or something went wrong');
       return ClusteringResult(newFaceIdToCluster: {});
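
Finally, a hedged sketch of what the caller-side guard above amounts to after this change; only the guard itself comes from the hunk, while the wrapper function, its parameters, and the service reference are assumptions for illustration (project imports omitted):

// Illustrative only: a breakupCluster-style caller relying on hasAllResults.
Future<ClusteringResult> clusterWithGuard(
  FaceClusteringService faceClusteringService, // assumed to be available to the caller
  Map<String, Uint8List> faceIdToEmbeddingBytes,
  Map<int, (Uint8List, int)> existingSummaries,
) async {
  final clusterResult = await faceClusteringService.predictWithinClusterComputer(
    faceIdToEmbeddingBytes,
    oldClusterSummaries: existingSummaries,
  );
  if (clusterResult == null || !clusterResult.hasAllResults || clusterResult.isEmpty) {
    // The real caller logs a warning here before bailing out with an empty result.
    return ClusteringResult(newFaceIdToCluster: {});
  }
  // From here on both newClusterSummaries and newClusterIdToFaceIds are non-null.
  return clusterResult;
}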