[mob][photos] First sort attempt

This commit is contained in:
laurenspriem 2025-01-17 15:01:20 +05:30
parent 1bb04f2650
commit e224609e7d

View File

@ -1,13 +1,17 @@
import 'dart:async';
import "dart:developer";
import "dart:developer" as dev;
import "dart:math" show max;
import "package:flutter/foundation.dart";
import "package:flutter/material.dart";
import "package:logging/logging.dart";
import "package:ml_linalg/linalg.dart";
import "package:photos/core/event_bus.dart";
import "package:photos/db/ml/db.dart";
import "package:photos/ente_theme_data.dart";
import "package:photos/events/people_changed_event.dart";
import "package:photos/generated/l10n.dart";
import "package:photos/generated/protos/ente/common/vector.pb.dart";
import "package:photos/l10n/l10n.dart";
import "package:photos/models/file/file.dart";
import "package:photos/models/ml/face/person.dart";
@ -309,7 +313,7 @@ class _SaveOrEditPersonState extends State<SaveOrEditPerson> {
stream: _getPersonsWithRecentFileStream(),
builder: (context, snapshot) {
if (snapshot.hasError) {
log("Error: ${snapshot.error} ${snapshot.stackTrace}}");
dev.log("Error: ${snapshot.error} ${snapshot.stackTrace}}");
if (kDebugMode) {
return Column(
mainAxisSize: MainAxisSize.min,
@ -338,6 +342,7 @@ class _SaveOrEditPersonState extends State<SaveOrEditPerson> {
if (searchResults.isEmpty) {
return const SizedBox.shrink();
}
final finalResults = await _sortByCosine(searchResults);
return Column(
crossAxisAlignment: CrossAxisAlignment.start,
@ -360,9 +365,9 @@ class _SaveOrEditPersonState extends State<SaveOrEditPerson> {
child: ListView.separated(
scrollDirection: Axis.horizontal,
padding: const EdgeInsets.only(right: 8),
itemCount: searchResults.length,
itemCount: finalResults.length,
itemBuilder: (context, index) {
final person = searchResults[index];
final person = finalResults[index];
return PersonGridItem(
key: ValueKey(person.$1.remoteID),
person: person.$1,
@ -406,6 +411,64 @@ class _SaveOrEditPersonState extends State<SaveOrEditPerson> {
yield _cachedPersons;
}
Future<List<(PersonEntity, EnteFile)>> _sortByCosine(
List<(PersonEntity, EnteFile)> searchResults,
) async {
if (widget.clusterID == null) return searchResults;
// Get current cluster embedding
final currentClusterSummary =
await MLDataDB.instance.getClusterToClusterSummary([widget.clusterID!]);
final currentClusterEmbeddingData =
currentClusterSummary[widget.clusterID!]?.$1;
if (currentClusterEmbeddingData == null) return searchResults;
final Vector currentClusterEmbedding = Vector.fromList(
EVector.fromBuffer(currentClusterEmbeddingData).values,
dtype: DType.float32,
);
// Get all cluster embeddings
final allClusterSummary = await MLDataDB.instance.getAllClusterSummary();
final persons = searchResults.map((e) => e.$1).toList();
final clusterToPerson = <String, String>{};
for (final person in persons) {
if (person.data.assigned != null) {
for (final cluster in person.data.assigned!) {
clusterToPerson[cluster.id] = person.remoteID;
}
}
}
allClusterSummary
.removeWhere((key, value) => !clusterToPerson.containsKey(key));
final Map<String, Vector> allClusterEmbeddings = allClusterSummary.map(
(key, value) => MapEntry(
key,
Vector.fromList(
EVector.fromBuffer(value.$1).values,
dtype: DType.float32,
),
),
);
// Calculate cosine similarity between current cluster and all clusters
final Map<String, double> personToMaxSimilarity = {};
for (final entry in allClusterEmbeddings.entries) {
final personId = clusterToPerson[entry.key]!;
final similarity = currentClusterEmbedding.dot(entry.value);
personToMaxSimilarity[personId] = max(
personToMaxSimilarity[personId] ?? double.negativeInfinity,
similarity,
);
}
// Sort search results based on cosine similarity
searchResults.sort((a, b) {
final similarityA = personToMaxSimilarity[a.$1.remoteID] ?? 0;
final similarityB = personToMaxSimilarity[b.$1.remoteID] ?? 0;
return similarityB.compareTo(similarityA);
});
return searchResults;
}
Future<void> addNewPerson(
BuildContext context, {
String text = '',