[mob] Use single db for ml data

This commit is contained in:
Neeraj Gupta 2024-08-14 14:22:02 +05:30
parent 810cf6f885
commit bfec2ff2be
2 changed files with 21 additions and 11 deletions

View File

@ -15,6 +15,8 @@ EntityType typeFromString(String type) {
return EntityType.location; return EntityType.location;
case "person_v2": case "person_v2":
return EntityType.personV2; return EntityType.personV2;
case "personV2":
return EntityType.personV2;
} }
debugPrint("unexpected collection type $type"); debugPrint("unexpected collection type $type");
return EntityType.unknown; return EntityType.unknown;

View File

@ -1,18 +1,19 @@
import "dart:async" show unawaited; import "dart:async" show unawaited;
import "dart:developer" as dev show log; import "dart:developer" as dev show log;
import "dart:math" show min; import "dart:math" show min;
import "dart:typed_data" show ByteData;
import "dart:ui" show Image; import "dart:ui" show Image;
import "package:computer/computer.dart"; import "package:computer/computer.dart";
import "package:flutter/foundation.dart";
import "package:logging/logging.dart"; import "package:logging/logging.dart";
import "package:photos/core/cache/lru_map.dart"; import "package:photos/core/cache/lru_map.dart";
import "package:photos/core/event_bus.dart"; import "package:photos/core/event_bus.dart";
import "package:photos/db/embeddings_db.dart";
import "package:photos/db/files_db.dart"; import "package:photos/db/files_db.dart";
import "package:photos/db/ml/db.dart";
import "package:photos/db/ml/embeddings_db.dart";
import 'package:photos/events/embedding_updated_event.dart'; import 'package:photos/events/embedding_updated_event.dart';
import "package:photos/models/embedding.dart";
import "package:photos/models/file/file.dart"; import "package:photos/models/file/file.dart";
import "package:photos/models/ml/clip.dart";
import "package:photos/models/ml/ml_versions.dart"; import "package:photos/models/ml/ml_versions.dart";
import "package:photos/service_locator.dart"; import "package:photos/service_locator.dart";
import "package:photos/services/collections_service.dart"; import "package:photos/services/collections_service.dart";
@ -57,7 +58,7 @@ class SemanticSearchService {
return; return;
} }
_hasInitialized = true; _hasInitialized = true;
await EmbeddingsDB.instance.init();
await _loadImageEmbeddings(); await _loadImageEmbeddings();
Bus.instance.on<EmbeddingUpdatedEvent>().listen((event) { Bus.instance.on<EmbeddingUpdatedEvent>().listen((event) {
if (!_hasInitialized) return; if (!_hasInitialized) return;
@ -112,7 +113,7 @@ class SemanticSearchService {
} }
Future<void> clearIndexes() async { Future<void> clearIndexes() async {
await EmbeddingsDB.instance.deleteAll(); await FaceMLDataDB.instance.deleteClipIndexes();
final preferences = await SharedPreferences.getInstance(); final preferences = await SharedPreferences.getInstance();
await preferences.remove("sync_time_embeddings_v3"); await preferences.remove("sync_time_embeddings_v3");
_logger.info("Indexes cleared"); _logger.info("Indexes cleared");
@ -121,7 +122,7 @@ class SemanticSearchService {
Future<void> _loadImageEmbeddings() async { Future<void> _loadImageEmbeddings() async {
_logger.info("Pulling cached embeddings"); _logger.info("Pulling cached embeddings");
final startTime = DateTime.now(); final startTime = DateTime.now();
_cachedImageEmbeddings = await EmbeddingsDB.instance.getAll(); _cachedImageEmbeddings = await FaceMLDataDB.instance.getAll();
final endTime = DateTime.now(); final endTime = DateTime.now();
_logger.info( _logger.info(
"Loading ${_cachedImageEmbeddings.length} took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)}ms", "Loading ${_cachedImageEmbeddings.length} took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)}ms",
@ -133,7 +134,7 @@ class SemanticSearchService {
Future<List<int>> _getFileIDsToBeIndexed() async { Future<List<int>> _getFileIDsToBeIndexed() async {
final uploadedFileIDs = await getIndexableFileIDs(); final uploadedFileIDs = await getIndexableFileIDs();
final embeddedFileIDs = await EmbeddingsDB.instance.getIndexedFileIds(); final embeddedFileIDs = await FaceMLDataDB.instance.getIndexedFileIds();
embeddedFileIDs.removeWhere((key, value) => value < clipMlVersion); embeddedFileIDs.removeWhere((key, value) => value < clipMlVersion);
return uploadedFileIDs.difference(embeddedFileIDs.keys.toSet()).toList(); return uploadedFileIDs.difference(embeddedFileIDs.keys.toSet()).toList();
@ -178,7 +179,7 @@ class SemanticSearchService {
_logger.info(results.length.toString() + " results"); _logger.info(results.length.toString() + " results");
if (deletedEntries.isNotEmpty) { if (deletedEntries.isNotEmpty) {
unawaited(EmbeddingsDB.instance.deleteEmbeddings(deletedEntries)); unawaited(FaceMLDataDB.instance.deleteEmbeddings(deletedEntries));
} }
return results; return results;
@ -221,7 +222,7 @@ class SemanticSearchService {
_logger.info(results.length.toString() + " results"); _logger.info(results.length.toString() + " results");
if (deletedEntries.isNotEmpty) { if (deletedEntries.isNotEmpty) {
unawaited(EmbeddingsDB.instance.deleteEmbeddings(deletedEntries)); unawaited(FaceMLDataDB.instance.deleteEmbeddings(deletedEntries));
} }
final matchingFileIDs = <int>[]; final matchingFileIDs = <int>[];
@ -253,12 +254,12 @@ class SemanticSearchService {
embedding: clipResult.embedding, embedding: clipResult.embedding,
version: clipMlVersion, version: clipMlVersion,
); );
await EmbeddingsDB.instance.put(embedding); await FaceMLDataDB.instance.put(embedding);
} }
static Future<void> storeEmptyClipImageResult(EnteFile entefile) async { static Future<void> storeEmptyClipImageResult(EnteFile entefile) async {
final embedding = ClipEmbedding.empty(entefile.uploadedFileID!); final embedding = ClipEmbedding.empty(entefile.uploadedFileID!);
await EmbeddingsDB.instance.put(embedding); await FaceMLDataDB.instance.put(embedding);
} }
Future<List<double>> _getTextEmbedding(String query) async { Future<List<double>> _getTextEmbedding(String query) async {
@ -320,6 +321,7 @@ List<QueryResult> computeBulkSimilarities(Map args) {
final textEmbedding = args["textEmbedding"] as List<double>; final textEmbedding = args["textEmbedding"] as List<double>;
final minimumSimilarity = args["minimumSimilarity"] ?? final minimumSimilarity = args["minimumSimilarity"] ??
SemanticSearchService.kMinimumSimilarityThreshold; SemanticSearchService.kMinimumSimilarityThreshold;
double bestScore = 0.0;
for (final imageEmbedding in imageEmbeddings) { for (final imageEmbedding in imageEmbeddings) {
final score = computeCosineSimilarity( final score = computeCosineSimilarity(
imageEmbedding.embedding, imageEmbedding.embedding,
@ -328,6 +330,12 @@ List<QueryResult> computeBulkSimilarities(Map args) {
if (score >= minimumSimilarity) { if (score >= minimumSimilarity) {
queryResults.add(QueryResult(imageEmbedding.fileID, score)); queryResults.add(QueryResult(imageEmbedding.fileID, score));
} }
if (score > bestScore) {
bestScore = score;
}
}
if (kDebugMode && queryResults.isEmpty) {
dev.log("No results found for query with best score: $bestScore");
} }
queryResults.sort((first, second) => second.score.compareTo(first.score)); queryResults.sort((first, second) => second.score.compareTo(first.score));