[mob] Add DBSCAN clustering for intra-cluster analysis
parent b21466bf13
commit 744ded4922
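This change renames `FaceLinearClustering` to `FaceClustering`, renames `predict` to `predictLinear`, and adds a second isolate operation, `dbscanClustering`, which runs DBSCAN from `package:simple_cluster` over face embeddings with a cosine distance measure. `ClusterFeedbackService.breakUpCluster` gains a `useDbscan` flag so an existing cluster can be re-clustered for intra-cluster analysis. Below is a rough usage sketch of the new surface, based only on the signatures visible in this diff; how the embeddings are loaded and the singleton access on `ClusterFeedbackService` are assumptions, not part of the commit.

```dart
// Sketch only: exercising the API added/renamed in this commit.
// Project imports (FaceClustering, ClusterFeedbackService) are assumed to be in scope.
import 'dart:typed_data';

Future<void> example(Map<String, Uint8List> embeddings, int clusterID) async {
  // DBSCAN over (faceID -> serialized embedding) pairs; returns lists of faceIDs per cluster.
  final List<List<String>> clusters = await FaceClustering.instance.predictDbscan(
    embeddings,
    eps: 0.25, // neighbourhood radius in cosine distance
    minPts: 4, // minimum neighbours needed to form a dense region
  );

  // Split an existing cluster, as the new breakUpCluster(useDbscan: true) path does.
  // Assumes ClusterFeedbackService exposes a singleton `instance`, like other services here.
  final Map<int, List<String>> pieces = await ClusterFeedbackService.instance
      .breakUpCluster(clusterID, useDbscan: true);

  print('DBSCAN produced ${clusters.length} clusters; '
      'breakUpCluster produced ${pieces.length} sub-clusters.');
}
```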
@@ -10,6 +10,7 @@ import "package:ml_linalg/vector.dart";
 import "package:photos/generated/protos/ente/common/vector.pb.dart";
 import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
 import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
+import "package:simple_cluster/simple_cluster.dart";
 import "package:synchronized/synchronized.dart";
 
 class FaceInfo {
@@ -29,9 +30,9 @@ class FaceInfo {
   });
 }
 
-enum ClusterOperation { linearIncrementalClustering }
+enum ClusterOperation { linearIncrementalClustering, dbscanClustering }
 
-class FaceLinearClustering {
+class FaceClustering {
   final _logger = Logger("FaceLinearClustering");
 
   Timer? _inactivityTimer;
@@ -50,12 +51,12 @@ class FaceLinearClustering {
   static const kRecommendedDistanceThreshold = 0.3;
 
   // singleton pattern
-  FaceLinearClustering._privateConstructor();
+  FaceClustering._privateConstructor();
 
   /// Use this instance to access the FaceClustering service.
   /// e.g. `FaceLinearClustering.instance.predict(dataset)`
-  static final instance = FaceLinearClustering._privateConstructor();
-  factory FaceLinearClustering() => instance;
+  static final instance = FaceClustering._privateConstructor();
+  factory FaceClustering() => instance;
 
   Future<void> init() async {
     return _initLock.synchronized(() async {
@@ -103,13 +104,27 @@ class FaceLinearClustering {
           final fileIDToCreationTime =
               args['fileIDToCreationTime'] as Map<int, int>?;
           final distanceThreshold = args['distanceThreshold'] as double;
-          final result = FaceLinearClustering._runLinearClustering(
+          final result = FaceClustering._runLinearClustering(
             input,
             fileIDToCreationTime: fileIDToCreationTime,
             distanceThreshold: distanceThreshold,
           );
           sendPort.send(result);
           break;
+        case ClusterOperation.dbscanClustering:
+          final input = args['input'] as Map<String, Uint8List>;
+          final fileIDToCreationTime =
+              args['fileIDToCreationTime'] as Map<int, int>?;
+          final eps = args['eps'] as double;
+          final minPts = args['minPts'] as int;
+          final result = FaceClustering._runDbscanClustering(
+            input,
+            fileIDToCreationTime: fileIDToCreationTime,
+            eps: eps,
+            minPts: minPts,
+          );
+          sendPort.send(result);
+          break;
       }
     } catch (e, stackTrace) {
       sendPort
@@ -182,7 +197,7 @@ class FaceLinearClustering {
   /// Returns the clustering result, which is a list of clusters, where each cluster is a list of indices of the dataset.
   ///
   /// WARNING: Make sure to always input data in the same ordering, otherwise the clustering can be less deterministic.
-  Future<Map<String, int>?> predict(
+  Future<Map<String, int>?> predictLinear(
     Map<String, (int?, Uint8List)> input, {
     Map<int, int>? fileIDToCreationTime,
     double distanceThreshold = kRecommendedDistanceThreshold,
@@ -210,7 +225,11 @@ class FaceLinearClustering {
     final Map<String, int> faceIdToCluster = await _runInIsolate(
       (
         ClusterOperation.linearIncrementalClustering,
-        {'input': input, 'fileIDToCreationTime': fileIDToCreationTime, 'distanceThreshold': distanceThreshold}
+        {
+          'input': input,
+          'fileIDToCreationTime': fileIDToCreationTime,
+          'distanceThreshold': distanceThreshold,
+        }
       ),
     );
     // return _runLinearClusteringInComputer(input);
@@ -223,6 +242,55 @@ class FaceLinearClustering {
     return faceIdToCluster;
   }
 
+  Future<List<List<String>>> predictDbscan(
+    Map<String, Uint8List> input, {
+    Map<int, int>? fileIDToCreationTime,
+    double eps = 0.3,
+    int minPts = 5,
+  }) async {
+    if (input.isEmpty) {
+      _logger.warning(
+        "DBSCAN Clustering dataset of embeddings is empty, returning empty list.",
+      );
+      return [];
+    }
+    if (isRunning) {
+      _logger.warning(
+        "DBSCAN Clustering is already running, returning empty list.",
+      );
+      return [];
+    }
+
+    isRunning = true;
+
+    // Clustering inside the isolate
+    _logger.info(
+      "Start DBSCAN clustering on ${input.length} embeddings inside computer isolate",
+    );
+    final stopwatchClustering = Stopwatch()..start();
+    // final Map<String, int> faceIdToCluster =
+    //     await _runLinearClusteringInComputer(input);
+    final List<List<String>> clusterFaceIDs = await _runInIsolate(
+      (
+        ClusterOperation.dbscanClustering,
+        {
+          'input': input,
+          'fileIDToCreationTime': fileIDToCreationTime,
+          'eps': eps,
+          'minPts': minPts,
+        }
+      ),
+    );
+    // return _runLinearClusteringInComputer(input);
+    _logger.info(
+      'DBSCAN Clustering executed in ${stopwatchClustering.elapsed.inSeconds} seconds',
+    );
+
+    isRunning = false;
+
+    return clusterFaceIDs;
+  }
+
   static Map<String, int> _runLinearClustering(
     Map<String, (int?, Uint8List)> x, {
     Map<int, int>? fileIDToCreationTime,
@@ -362,7 +430,7 @@ class FaceLinearClustering {
     );
 
     // analyze the results
-    FaceLinearClustering._analyzeClusterResults(sortedFaceInfos);
+    FaceClustering._analyzeClusterResults(sortedFaceInfos);
 
     return newFaceIdToCluster;
   }
@@ -424,4 +492,64 @@ class FaceLinearClustering {
       "[ClusterIsolate] Clustering additional analysis took ${stopwatch.elapsedMilliseconds} ms",
     );
   }
+
+  static List<List<String>> _runDbscanClustering(
+    Map<String, Uint8List> x, {
+    Map<int, int>? fileIDToCreationTime,
+    double eps = 0.3,
+    int minPts = 5,
+  }) {
+    log(
+      "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${x.length} faces",
+    );
+
+    DBSCAN dbscan = DBSCAN(
+      epsilon: eps,
+      minPoints: minPts,
+      distanceMeasure: cosineDistForNormVectors,
+    );
+
+    // Organize everything into a list of FaceInfo objects
+    final List<FaceInfo> faceInfos = [];
+    for (final entry in x.entries) {
+      faceInfos.add(
+        FaceInfo(
+          faceID: entry.key,
+          embedding: EVector.fromBuffer(entry.value).values,
+          fileCreationTime:
+              fileIDToCreationTime?[getFileIdFromFaceId(entry.key)],
+        ),
+      );
+    }
+
+    // Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first
+    if (fileIDToCreationTime != null) {
+      faceInfos.sort((a, b) {
+        if (a.fileCreationTime == null && b.fileCreationTime == null) {
+          return 0;
+        } else if (a.fileCreationTime == null) {
+          return 1;
+        } else if (b.fileCreationTime == null) {
+          return -1;
+        } else {
+          return a.fileCreationTime!.compareTo(b.fileCreationTime!);
+        }
+      });
+    }
+
+    // Get the embeddings
+    final List<List<double>> embeddings =
+        faceInfos.map((faceInfo) => faceInfo.embedding!).toList();
+
+    // Run the DBSCAN clustering
+    final List<List<int>> clusterOutput = dbscan.run(embeddings);
+    final List<List<FaceInfo>> clusteredFaceInfos = clusterOutput
+        .map((cluster) => cluster.map((idx) => faceInfos[idx]).toList())
+        .toList();
+    final List<List<String>> clusteredFaceIDs = clusterOutput
+        .map((cluster) => cluster.map((idx) => faceInfos[idx].faceID).toList())
+        .toList();
+
+    return clusteredFaceIDs;
+  }
 }
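A note on the distance measure wired into `_runDbscanClustering` above: `distanceMeasure: cosineDistForNormVectors` comes from the existing `cosine_distance.dart` import. Assuming the face embeddings are already L2-normalised, as the helper's name suggests, cosine distance reduces to one minus the dot product, so the `eps: 0.25` passed from `breakUpCluster` admits neighbours whose cosine similarity is at least 0.75. The sketch below restates that relationship; it is an assumption about what the helper computes, not its actual implementation.

```dart
// Sketch: cosine distance for unit-length vectors, assumed to mirror the
// behaviour of `cosineDistForNormVectors` passed as `distanceMeasure` above.
double cosineDistanceForNormalizedVectors(List<double> a, List<double> b) {
  assert(a.length == b.length, 'Embeddings must have the same dimension');
  double dot = 0.0;
  for (int i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
  }
  // For normalized vectors, cosine similarity equals the dot product,
  // so cosine distance = 1 - similarity, giving values in [0, 2].
  return 1.0 - dot;
}
```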
@@ -407,7 +407,7 @@ class FaceMlService {
         break;
       }
 
-      final faceIdToCluster = await FaceLinearClustering.instance.predict(
+      final faceIdToCluster = await FaceClustering.instance.predictLinear(
         faceIdToEmbeddingBucket,
         fileIDToCreationTime: fileIDToCreationTime,
       );
@@ -439,7 +439,7 @@
         '${DateTime.now().difference(gotFaceEmbeddingsTime).inMilliseconds} ms');
 
     // Cluster the embeddings using the linear clustering algorithm, returning a map from faceID to clusterID
-    final faceIdToCluster = await FaceLinearClustering.instance.predict(
+    final faceIdToCluster = await FaceClustering.instance.predictLinear(
       faceIdToEmbedding,
       fileIDToCreationTime: fileIDToCreationTime,
     );
@@ -4,6 +4,7 @@ import "dart:math" show Random;
 import "package:flutter/foundation.dart";
 import "package:logging/logging.dart";
 import "package:photos/core/event_bus.dart";
+import "package:photos/db/files_db.dart";
 import "package:photos/events/people_changed_event.dart";
 import "package:photos/extensions/stop_watch.dart";
 import "package:photos/face/db.dart";
@@ -40,6 +41,7 @@ class ClusterFeedbackService {
   static setLastViewedClusterID(int clusterID) {
     lastViewedClusterID = clusterID;
   }
+
   static resetLastViewedClusterID() {
     lastViewedClusterID = -1;
   }
@@ -389,7 +391,10 @@ class ClusterFeedbackService {
   }
 
   // TODO: iterate over this method and actually use it
-  Future<Map<int, List<String>>> breakUpCluster(int clusterID) async {
+  Future<Map<int, List<String>>> breakUpCluster(
+    int clusterID, {
+    useDbscan = true,
+  }) async {
     final faceMlDb = FaceMLDataDB.instance;
 
     final faceIDs = await faceMlDb.getFaceIDsForCluster(clusterID);
@@ -397,24 +402,52 @@
     final embeddings = await faceMlDb.getFaceEmbeddingMapForFile(fileIDs);
     embeddings.removeWhere((key, value) => !faceIDs.contains(key));
-    final clusteringInput = embeddings.map((key, value) {
-      return MapEntry(key, (null, value));
-    });
-
-    final faceIdToCluster = await FaceLinearClustering.instance
-        .predict(clusteringInput, distanceThreshold: 0.15);
-
-    if (faceIdToCluster == null) {
-      return {};
-    }
-    final clusterIdToFaceIds = <int, List<String>>{};
-    for (final entry in faceIdToCluster.entries) {
-      final clusterID = entry.value;
-      if (clusterIdToFaceIds.containsKey(clusterID)) {
-        clusterIdToFaceIds[clusterID]!.add(entry.key);
-      } else {
-        clusterIdToFaceIds[clusterID] = [entry.key];
-      }
-    }
+    final fileIDToCreationTime =
+        await FilesDB.instance.getFileIDToCreationTime();
+
+    final Map<int, List<String>> clusterIdToFaceIds = {};
+    if (useDbscan) {
+      final dbscanClusters = await FaceClustering.instance.predictDbscan(
+        embeddings,
+        fileIDToCreationTime: fileIDToCreationTime,
+        eps: 0.25,
+        minPts: 4,
+      );
+
+      if (dbscanClusters.isEmpty) {
+        return {};
+      }
+
+      int maxClusterID = DateTime.now().millisecondsSinceEpoch;
+
+      for (final List<String> cluster in dbscanClusters) {
+        final faceIds = cluster;
+        clusterIdToFaceIds[maxClusterID] = faceIds;
+        maxClusterID++;
+      }
+    } else {
+      final clusteringInput = embeddings.map((key, value) {
+        return MapEntry(key, (null, value));
+      });
+
+      final faceIdToCluster = await FaceClustering.instance.predictLinear(
+        clusteringInput,
+        fileIDToCreationTime: fileIDToCreationTime,
+        distanceThreshold: 0.15,
+      );
+
+      if (faceIdToCluster == null) {
+        return {};
+      }
+
+      for (final entry in faceIdToCluster.entries) {
+        final clusterID = entry.value;
+        if (clusterIdToFaceIds.containsKey(clusterID)) {
+          clusterIdToFaceIds[clusterID]!.add(entry.key);
+        } else {
+          clusterIdToFaceIds[clusterID] = [entry.key];
+        }
+      }
+    }
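On the parameters used in `breakUpCluster` (`eps: 0.25`, `minPts: 4`): in DBSCAN a point is a core point when at least `minPts` points, itself included, lie within distance `eps`; clusters grow outward from core points, and the remaining points are treated as noise, so faces flagged as noise would presumably not appear in any returned sub-cluster. The check below is only an illustrative restatement of that core-point rule, not the `simple_cluster` implementation driven above.

```dart
// Illustrative only: the DBSCAN core-point test with the values breakUpCluster passes.
bool isCorePoint(
  List<double> candidate,
  List<List<double>> allPoints,
  double Function(List<double>, List<double>) distance, {
  double eps = 0.25,
  int minPts = 4,
}) {
  // Count points within eps of the candidate (the candidate itself is in allPoints).
  final int neighbours =
      allPoints.where((point) => distance(candidate, point) <= eps).length;
  return neighbours >= minPts;
}
```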