mirror of
https://github.com/ente-io/ente.git
synced 2025-08-08 07:28:26 +00:00
[mob][photos] First sort attempt
This commit is contained in:
parent
1bb04f2650
commit
e224609e7d
@ -1,13 +1,17 @@
|
||||
import 'dart:async';
|
||||
import "dart:developer";
|
||||
import "dart:developer" as dev;
|
||||
import "dart:math" show max;
|
||||
|
||||
import "package:flutter/foundation.dart";
|
||||
import "package:flutter/material.dart";
|
||||
import "package:logging/logging.dart";
|
||||
import "package:ml_linalg/linalg.dart";
|
||||
import "package:photos/core/event_bus.dart";
|
||||
import "package:photos/db/ml/db.dart";
|
||||
import "package:photos/ente_theme_data.dart";
|
||||
import "package:photos/events/people_changed_event.dart";
|
||||
import "package:photos/generated/l10n.dart";
|
||||
import "package:photos/generated/protos/ente/common/vector.pb.dart";
|
||||
import "package:photos/l10n/l10n.dart";
|
||||
import "package:photos/models/file/file.dart";
|
||||
import "package:photos/models/ml/face/person.dart";
|
||||
@ -309,7 +313,7 @@ class _SaveOrEditPersonState extends State<SaveOrEditPerson> {
|
||||
stream: _getPersonsWithRecentFileStream(),
|
||||
builder: (context, snapshot) {
|
||||
if (snapshot.hasError) {
|
||||
log("Error: ${snapshot.error} ${snapshot.stackTrace}}");
|
||||
dev.log("Error: ${snapshot.error} ${snapshot.stackTrace}}");
|
||||
if (kDebugMode) {
|
||||
return Column(
|
||||
mainAxisSize: MainAxisSize.min,
|
||||
@ -338,6 +342,7 @@ class _SaveOrEditPersonState extends State<SaveOrEditPerson> {
|
||||
if (searchResults.isEmpty) {
|
||||
return const SizedBox.shrink();
|
||||
}
|
||||
final finalResults = await _sortByCosine(searchResults);
|
||||
|
||||
return Column(
|
||||
crossAxisAlignment: CrossAxisAlignment.start,
|
||||
@ -360,9 +365,9 @@ class _SaveOrEditPersonState extends State<SaveOrEditPerson> {
|
||||
child: ListView.separated(
|
||||
scrollDirection: Axis.horizontal,
|
||||
padding: const EdgeInsets.only(right: 8),
|
||||
itemCount: searchResults.length,
|
||||
itemCount: finalResults.length,
|
||||
itemBuilder: (context, index) {
|
||||
final person = searchResults[index];
|
||||
final person = finalResults[index];
|
||||
return PersonGridItem(
|
||||
key: ValueKey(person.$1.remoteID),
|
||||
person: person.$1,
|
||||
@ -406,6 +411,64 @@ class _SaveOrEditPersonState extends State<SaveOrEditPerson> {
|
||||
yield _cachedPersons;
|
||||
}
|
||||
|
||||
Future<List<(PersonEntity, EnteFile)>> _sortByCosine(
|
||||
List<(PersonEntity, EnteFile)> searchResults,
|
||||
) async {
|
||||
if (widget.clusterID == null) return searchResults;
|
||||
|
||||
// Get current cluster embedding
|
||||
final currentClusterSummary =
|
||||
await MLDataDB.instance.getClusterToClusterSummary([widget.clusterID!]);
|
||||
final currentClusterEmbeddingData =
|
||||
currentClusterSummary[widget.clusterID!]?.$1;
|
||||
if (currentClusterEmbeddingData == null) return searchResults;
|
||||
final Vector currentClusterEmbedding = Vector.fromList(
|
||||
EVector.fromBuffer(currentClusterEmbeddingData).values,
|
||||
dtype: DType.float32,
|
||||
);
|
||||
|
||||
// Get all cluster embeddings
|
||||
final allClusterSummary = await MLDataDB.instance.getAllClusterSummary();
|
||||
final persons = searchResults.map((e) => e.$1).toList();
|
||||
final clusterToPerson = <String, String>{};
|
||||
for (final person in persons) {
|
||||
if (person.data.assigned != null) {
|
||||
for (final cluster in person.data.assigned!) {
|
||||
clusterToPerson[cluster.id] = person.remoteID;
|
||||
}
|
||||
}
|
||||
}
|
||||
allClusterSummary
|
||||
.removeWhere((key, value) => !clusterToPerson.containsKey(key));
|
||||
final Map<String, Vector> allClusterEmbeddings = allClusterSummary.map(
|
||||
(key, value) => MapEntry(
|
||||
key,
|
||||
Vector.fromList(
|
||||
EVector.fromBuffer(value.$1).values,
|
||||
dtype: DType.float32,
|
||||
),
|
||||
),
|
||||
);
|
||||
// Calculate cosine similarity between current cluster and all clusters
|
||||
final Map<String, double> personToMaxSimilarity = {};
|
||||
for (final entry in allClusterEmbeddings.entries) {
|
||||
final personId = clusterToPerson[entry.key]!;
|
||||
final similarity = currentClusterEmbedding.dot(entry.value);
|
||||
personToMaxSimilarity[personId] = max(
|
||||
personToMaxSimilarity[personId] ?? double.negativeInfinity,
|
||||
similarity,
|
||||
);
|
||||
}
|
||||
// Sort search results based on cosine similarity
|
||||
searchResults.sort((a, b) {
|
||||
final similarityA = personToMaxSimilarity[a.$1.remoteID] ?? 0;
|
||||
final similarityB = personToMaxSimilarity[b.$1.remoteID] ?? 0;
|
||||
return similarityB.compareTo(similarityA);
|
||||
});
|
||||
|
||||
return searchResults;
|
||||
}
|
||||
|
||||
Future<void> addNewPerson(
|
||||
BuildContext context, {
|
||||
String text = '',
|
||||
|
Loading…
x
Reference in New Issue
Block a user