[mob] Add DBSCAN clustering for intra-cluster analysis

This commit is contained in:
laurenspriem 2024-04-03 18:49:43 +05:30
parent b21466bf13
commit 744ded4922
3 changed files with 188 additions and 27 deletions

View File

@ -10,6 +10,7 @@ import "package:ml_linalg/vector.dart";
import "package:photos/generated/protos/ente/common/vector.pb.dart";
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
import "package:simple_cluster/simple_cluster.dart";
import "package:synchronized/synchronized.dart";
class FaceInfo {
@ -29,9 +30,9 @@ class FaceInfo {
});
}
enum ClusterOperation { linearIncrementalClustering }
enum ClusterOperation { linearIncrementalClustering, dbscanClustering }
class FaceLinearClustering {
class FaceClustering {
final _logger = Logger("FaceLinearClustering");
Timer? _inactivityTimer;
@ -50,12 +51,12 @@ class FaceLinearClustering {
static const kRecommendedDistanceThreshold = 0.3;
// singleton pattern
FaceLinearClustering._privateConstructor();
FaceClustering._privateConstructor();
/// Use this instance to access the FaceClustering service.
/// e.g. `FaceLinearClustering.instance.predict(dataset)`
static final instance = FaceLinearClustering._privateConstructor();
factory FaceLinearClustering() => instance;
static final instance = FaceClustering._privateConstructor();
factory FaceClustering() => instance;
Future<void> init() async {
return _initLock.synchronized(() async {
@ -103,13 +104,27 @@ class FaceLinearClustering {
final fileIDToCreationTime =
args['fileIDToCreationTime'] as Map<int, int>?;
final distanceThreshold = args['distanceThreshold'] as double;
final result = FaceLinearClustering._runLinearClustering(
final result = FaceClustering._runLinearClustering(
input,
fileIDToCreationTime: fileIDToCreationTime,
distanceThreshold: distanceThreshold,
);
sendPort.send(result);
break;
case ClusterOperation.dbscanClustering:
final input = args['input'] as Map<String, Uint8List>;
final fileIDToCreationTime =
args['fileIDToCreationTime'] as Map<int, int>?;
final eps = args['eps'] as double;
final minPts = args['minPts'] as int;
final result = FaceClustering._runDbscanClustering(
input,
fileIDToCreationTime: fileIDToCreationTime,
eps: eps,
minPts: minPts,
);
sendPort.send(result);
break;
}
} catch (e, stackTrace) {
sendPort
@ -182,7 +197,7 @@ class FaceLinearClustering {
/// Returns the clustering result, which is a list of clusters, where each cluster is a list of indices of the dataset.
///
/// WARNING: Make sure to always input data in the same ordering, otherwise the clustering can be less deterministic.
Future<Map<String, int>?> predict(
Future<Map<String, int>?> predictLinear(
Map<String, (int?, Uint8List)> input, {
Map<int, int>? fileIDToCreationTime,
double distanceThreshold = kRecommendedDistanceThreshold,
@ -210,7 +225,11 @@ class FaceLinearClustering {
final Map<String, int> faceIdToCluster = await _runInIsolate(
(
ClusterOperation.linearIncrementalClustering,
{'input': input, 'fileIDToCreationTime': fileIDToCreationTime, 'distanceThreshold': distanceThreshold}
{
'input': input,
'fileIDToCreationTime': fileIDToCreationTime,
'distanceThreshold': distanceThreshold,
}
),
);
// return _runLinearClusteringInComputer(input);
@ -223,6 +242,55 @@ class FaceLinearClustering {
return faceIdToCluster;
}
/// Runs DBSCAN clustering on [input] embeddings inside the isolate.
///
/// [input] maps faceID to a serialized embedding (protobuf `EVector` bytes).
/// [fileIDToCreationTime] lets the isolate sort faces by their file's
/// creation time before clustering, for more deterministic results.
/// [eps] is the DBSCAN neighborhood radius (cosine distance) and [minPts]
/// the minimum neighbor count for a core point.
///
/// Returns a list of clusters, each a list of faceIDs. Returns an empty
/// list when [input] is empty or another clustering run is in progress.
Future<List<List<String>>> predictDbscan(
  Map<String, Uint8List> input, {
  Map<int, int>? fileIDToCreationTime,
  double eps = 0.3,
  int minPts = 5,
}) async {
  if (input.isEmpty) {
    _logger.warning(
      "DBSCAN Clustering dataset of embeddings is empty, returning empty list.",
    );
    return [];
  }
  if (isRunning) {
    _logger.warning(
      "DBSCAN Clustering is already running, returning empty list.",
    );
    return [];
  }

  isRunning = true;
  try {
    // Clustering inside the isolate
    _logger.info(
      "Start DBSCAN clustering on ${input.length} embeddings inside computer isolate",
    );
    final stopwatchClustering = Stopwatch()..start();

    final List<List<String>> clusterFaceIDs = await _runInIsolate(
      (
        ClusterOperation.dbscanClustering,
        {
          'input': input,
          'fileIDToCreationTime': fileIDToCreationTime,
          'eps': eps,
          'minPts': minPts,
        }
      ),
    );

    _logger.info(
      'DBSCAN Clustering executed in ${stopwatchClustering.elapsed.inSeconds} seconds',
    );
    return clusterFaceIDs;
  } finally {
    // BUGFIX: always clear the flag, even when the isolate run throws.
    // Previously an exception left isRunning == true forever, silently
    // blocking every subsequent clustering request.
    isRunning = false;
  }
}
static Map<String, int> _runLinearClustering(
Map<String, (int?, Uint8List)> x, {
Map<int, int>? fileIDToCreationTime,
@ -362,7 +430,7 @@ class FaceLinearClustering {
);
// analyze the results
FaceLinearClustering._analyzeClusterResults(sortedFaceInfos);
FaceClustering._analyzeClusterResults(sortedFaceInfos);
return newFaceIdToCluster;
}
@ -424,4 +492,64 @@ class FaceLinearClustering {
"[ClusterIsolate] Clustering additional analysis took ${stopwatch.elapsedMilliseconds} ms",
);
}
/// Runs DBSCAN over the embeddings in [x] (faceID -> serialized `EVector`
/// bytes) and returns the resulting clusters as lists of faceIDs.
///
/// When [fileIDToCreationTime] is provided, faces are sorted by their
/// file's creation time (oldest first, unknown times last) before
/// clustering, so the input ordering — and hence the output — is
/// deterministic across runs.
///
/// NOTE(review): DBSCAN's output here appears to contain only clustered
/// points (noise excluded) — confirm against simple_cluster's DBSCAN docs.
static List<List<String>> _runDbscanClustering(
  Map<String, Uint8List> x, {
  Map<int, int>? fileIDToCreationTime,
  double eps = 0.3,
  int minPts = 5,
}) {
  log(
    "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${x.length} faces",
  );

  final dbscan = DBSCAN(
    epsilon: eps,
    minPoints: minPts,
    distanceMeasure: cosineDistForNormVectors,
  );

  // Deserialize every embedding into a FaceInfo, pairing it with its
  // file's creation time (if known) for the deterministic sort below.
  final List<FaceInfo> faceInfos = [
    for (final entry in x.entries)
      FaceInfo(
        faceID: entry.key,
        embedding: EVector.fromBuffer(entry.value).values,
        fileCreationTime:
            fileIDToCreationTime?[getFileIdFromFaceId(entry.key)],
      ),
  ];

  // Sort ascending by fileCreationTime so the oldest faces come first;
  // faces with unknown creation time go last.
  if (fileIDToCreationTime != null) {
    faceInfos.sort((a, b) {
      if (a.fileCreationTime == null && b.fileCreationTime == null) {
        return 0;
      } else if (a.fileCreationTime == null) {
        return 1;
      } else if (b.fileCreationTime == null) {
        return -1;
      } else {
        return a.fileCreationTime!.compareTo(b.fileCreationTime!);
      }
    });
  }

  // Extract the embedding vectors in the (now sorted) face order.
  final List<List<double>> embeddings =
      faceInfos.map((faceInfo) => faceInfo.embedding!).toList();

  // DBSCAN returns clusters of indices into `embeddings`; map each index
  // back to its faceID. (The previous version also built an unused
  // List<List<FaceInfo>> here; removed as dead code.)
  final List<List<int>> clusterOutput = dbscan.run(embeddings);
  return clusterOutput
      .map((cluster) => cluster.map((idx) => faceInfos[idx].faceID).toList())
      .toList();
}
}

View File

@ -407,7 +407,7 @@ class FaceMlService {
break;
}
final faceIdToCluster = await FaceLinearClustering.instance.predict(
final faceIdToCluster = await FaceClustering.instance.predictLinear(
faceIdToEmbeddingBucket,
fileIDToCreationTime: fileIDToCreationTime,
);
@ -439,7 +439,7 @@ class FaceMlService {
'${DateTime.now().difference(gotFaceEmbeddingsTime).inMilliseconds} ms');
// Cluster the embeddings using the linear clustering algorithm, returning a map from faceID to clusterID
final faceIdToCluster = await FaceLinearClustering.instance.predict(
final faceIdToCluster = await FaceClustering.instance.predictLinear(
faceIdToEmbedding,
fileIDToCreationTime: fileIDToCreationTime,
);

View File

@ -4,6 +4,7 @@ import "dart:math" show Random;
import "package:flutter/foundation.dart";
import "package:logging/logging.dart";
import "package:photos/core/event_bus.dart";
import "package:photos/db/files_db.dart";
import "package:photos/events/people_changed_event.dart";
import "package:photos/extensions/stop_watch.dart";
import "package:photos/face/db.dart";
@ -40,6 +41,7 @@ class ClusterFeedbackService {
static setLastViewedClusterID(int clusterID) {
lastViewedClusterID = clusterID;
}
static resetLastViewedClusterID() {
lastViewedClusterID = -1;
}
@ -389,7 +391,10 @@ class ClusterFeedbackService {
}
// TODO: iterate over this method and actually use it
Future<Map<int, List<String>>> breakUpCluster(int clusterID) async {
Future<Map<int, List<String>>> breakUpCluster(
int clusterID, {
useDbscan = true,
}) async {
final faceMlDb = FaceMLDataDB.instance;
final faceIDs = await faceMlDb.getFaceIDsForCluster(clusterID);
@ -397,24 +402,52 @@ class ClusterFeedbackService {
final embeddings = await faceMlDb.getFaceEmbeddingMapForFile(fileIDs);
embeddings.removeWhere((key, value) => !faceIDs.contains(key));
final clusteringInput = embeddings.map((key, value) {
return MapEntry(key, (null, value));
});
final faceIdToCluster = await FaceLinearClustering.instance
.predict(clusteringInput, distanceThreshold: 0.15);
final fileIDToCreationTime =
await FilesDB.instance.getFileIDToCreationTime();
if (faceIdToCluster == null) {
return {};
}
final Map<int, List<String>> clusterIdToFaceIds = {};
if (useDbscan) {
final dbscanClusters = await FaceClustering.instance.predictDbscan(
embeddings,
fileIDToCreationTime: fileIDToCreationTime,
eps: 0.25,
minPts: 4,
);
final clusterIdToFaceIds = <int, List<String>>{};
for (final entry in faceIdToCluster.entries) {
final clusterID = entry.value;
if (clusterIdToFaceIds.containsKey(clusterID)) {
clusterIdToFaceIds[clusterID]!.add(entry.key);
} else {
clusterIdToFaceIds[clusterID] = [entry.key];
if (dbscanClusters.isEmpty) {
return {};
}
int maxClusterID = DateTime.now().millisecondsSinceEpoch;
for (final List<String> cluster in dbscanClusters) {
final faceIds = cluster;
clusterIdToFaceIds[maxClusterID] = faceIds;
maxClusterID++;
}
} else {
final clusteringInput = embeddings.map((key, value) {
return MapEntry(key, (null, value));
});
final faceIdToCluster = await FaceClustering.instance.predictLinear(
clusteringInput,
fileIDToCreationTime: fileIDToCreationTime,
distanceThreshold: 0.15,
);
if (faceIdToCluster == null) {
return {};
}
for (final entry in faceIdToCluster.entries) {
final clusterID = entry.value;
if (clusterIdToFaceIds.containsKey(clusterID)) {
clusterIdToFaceIds[clusterID]!.add(entry.key);
} else {
clusterIdToFaceIds[clusterID] = [entry.key];
}
}
}