[mob] Put clip queries in same db class

This commit is contained in:
Neeraj Gupta 2025-01-27 16:27:20 +05:30
parent 771d12bd9b
commit 69661b0d30
6 changed files with 111 additions and 109 deletions

View File

@ -1,6 +1,8 @@
import "dart:typed_data";
import "package:photos/models/ml/clip.dart";
import "package:photos/models/ml/face/face.dart";
import "package:photos/models/ml/vector.dart";
import "package:photos/services/machine_learning/face_ml/face_clustering/face_db_info_for_clustering.dart";
abstract class IMLDataDB {
@ -103,4 +105,11 @@ abstract class IMLDataDB {
Future<Set<int>> getAllFilesAssociatedWithAllClusters({
List<String>? exceptClusters,
});
Future<List<EmbeddingVector>> getAllClipVectors();
Future<Map<int, int>> clipIndexedFileWithVersion();
Future<int> getClipIndexedFileCount({int minimumMlVersion});
Future<void> putClip(List<ClipEmbedding> embeddings);
Future<void> deleteClipEmbeddings(List<int> fileIDs);
Future<void> deleteClipIndexes();
}

View File

@ -1,100 +0,0 @@
import "dart:typed_data";
import "package:logging/logging.dart";
import "package:photos/core/event_bus.dart";
import "package:photos/db/ml/db.dart";
import "package:photos/db/ml/db_fields.dart";
import "package:photos/events/embedding_updated_event.dart";
import "package:photos/models/ml/clip.dart";
import "package:photos/models/ml/ml_versions.dart";
import "package:photos/models/ml/vector.dart";
extension ClipDB on MLDataDB {
Future<List<EmbeddingVector>> getAllClipVectors() async {
Logger("ClipDB").info("reading all embeddings from DB");
final db = await MLDataDB.instance.asyncDB;
final results = await db.getAll('SELECT * FROM $clipTable');
return _convertToVectors(results);
}
// Get indexed FileIDs
Future<Map<int, int>> clipIndexedFileWithVersion() async {
final db = await MLDataDB.instance.asyncDB;
final maps = await db
.getAll('SELECT $fileIDColumn , $mlVersionColumn FROM $clipTable');
final Map<int, int> result = {};
for (final map in maps) {
result[map[fileIDColumn] as int] = map[mlVersionColumn] as int;
}
return result;
}
Future<int> getClipIndexedFileCount({
int minimumMlVersion = clipMlVersion,
}) async {
final db = await MLDataDB.instance.asyncDB;
final String query =
'SELECT COUNT(DISTINCT $fileIDColumn) as count FROM $clipTable WHERE $mlVersionColumn >= $minimumMlVersion';
final List<Map<String, dynamic>> maps = await db.getAll(query);
return maps.first['count'] as int;
}
Future<void> put(ClipEmbedding embedding) async {
final db = await MLDataDB.instance.asyncDB;
await db.execute(
'INSERT OR REPLACE INTO $clipTable ($fileIDColumn, $embeddingColumn, $mlVersionColumn) VALUES (?, ?, ?)',
_getRowFromEmbedding(embedding),
);
Bus.instance.fire(EmbeddingUpdatedEvent());
}
Future<void> putMany(List<ClipEmbedding> embeddings) async {
if (embeddings.isEmpty) return;
final db = await MLDataDB.instance.asyncDB;
final inputs = embeddings.map((e) => _getRowFromEmbedding(e)).toList();
await db.executeBatch(
'INSERT OR REPLACE INTO $clipTable ($fileIDColumn, $embeddingColumn, $mlVersionColumn) values(?, ?, ?)',
inputs,
);
Bus.instance.fire(EmbeddingUpdatedEvent());
}
Future<void> deleteClipEmbeddings(List<int> fileIDs) async {
final db = await MLDataDB.instance.asyncDB;
await db.execute(
'DELETE FROM $clipTable WHERE $fileIDColumn IN (${fileIDs.join(", ")})',
);
Bus.instance.fire(EmbeddingUpdatedEvent());
}
Future<void> deleteClipIndexes() async {
final db = await MLDataDB.instance.asyncDB;
await db.execute('DELETE FROM $clipTable');
Bus.instance.fire(EmbeddingUpdatedEvent());
}
List<EmbeddingVector> _convertToVectors(List<Map<String, dynamic>> results) {
final List<EmbeddingVector> embeddings = [];
for (final result in results) {
final embedding = _getVectorFromRow(result);
if (embedding.isEmpty) continue;
embeddings.add(embedding);
}
return embeddings;
}
EmbeddingVector _getVectorFromRow(Map<String, dynamic> row) {
final fileID = row[fileIDColumn] as int;
final bytes = row[embeddingColumn] as Uint8List;
final list = Float32List.view(bytes.buffer);
return EmbeddingVector(fileID: fileID, embedding: list);
}
List<Object?> _getRowFromEmbedding(ClipEmbedding embedding) {
return [
embedding.fileID,
Float32List.fromList(embedding.embedding).buffer.asUint8List(),
embedding.version,
];
}
}

View File

@ -6,12 +6,16 @@ import "package:flutter/foundation.dart";
import 'package:logging/logging.dart';
import 'package:path/path.dart' show join;
import 'package:path_provider/path_provider.dart';
import "package:photos/core/event_bus.dart";
import "package:photos/db/ml/base.dart";
import 'package:photos/db/ml/db_fields.dart';
import "package:photos/db/ml/db_model_mappers.dart";
import "package:photos/events/embedding_updated_event.dart";
import "package:photos/extensions/stop_watch.dart";
import "package:photos/models/ml/clip.dart";
import "package:photos/models/ml/face/face.dart";
import "package:photos/models/ml/ml_versions.dart";
import "package:photos/models/ml/vector.dart";
import "package:photos/services/machine_learning/face_ml/face_clustering/face_db_info_for_clustering.dart";
import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart';
import "package:photos/services/machine_learning/ml_result.dart";
@ -1151,4 +1155,96 @@ class MLDataDB extends IMLDataDB {
return <int>{for (final row in result) row[fileIDColumn]};
}
@override
Future<List<EmbeddingVector>> getAllClipVectors() async {
Logger("ClipDB").info("reading all embeddings from DB");
final db = await MLDataDB.instance.asyncDB;
final results = await db.getAll('SELECT * FROM $clipTable');
return _convertToVectors(results);
}
// Get indexed FileIDs
@override
Future<Map<int, int>> clipIndexedFileWithVersion() async {
final db = await MLDataDB.instance.asyncDB;
final maps = await db
.getAll('SELECT $fileIDColumn , $mlVersionColumn FROM $clipTable');
final Map<int, int> result = {};
for (final map in maps) {
result[map[fileIDColumn] as int] = map[mlVersionColumn] as int;
}
return result;
}
@override
Future<int> getClipIndexedFileCount({
int minimumMlVersion = clipMlVersion,
}) async {
final db = await MLDataDB.instance.asyncDB;
final String query =
'SELECT COUNT(DISTINCT $fileIDColumn) as count FROM $clipTable WHERE $mlVersionColumn >= $minimumMlVersion';
final List<Map<String, dynamic>> maps = await db.getAll(query);
return maps.first['count'] as int;
}
@override
Future<void> putClip(List<ClipEmbedding> embeddings) async {
if (embeddings.isEmpty) return;
final db = await MLDataDB.instance.asyncDB;
if (embeddings.length == 1) {
await db.execute(
'INSERT OR REPLACE INTO $clipTable ($fileIDColumn, $embeddingColumn, $mlVersionColumn) VALUES (?, ?, ?)',
_getRowFromEmbedding(embeddings.first),
);
} else {
final inputs = embeddings.map((e) => _getRowFromEmbedding(e)).toList();
await db.executeBatch(
'INSERT OR REPLACE INTO $clipTable ($fileIDColumn, $embeddingColumn, $mlVersionColumn) values(?, ?, ?)',
inputs,
);
}
Bus.instance.fire(EmbeddingUpdatedEvent());
}
@override
Future<void> deleteClipEmbeddings(List<int> fileIDs) async {
final db = await MLDataDB.instance.asyncDB;
await db.execute(
'DELETE FROM $clipTable WHERE $fileIDColumn IN (${fileIDs.join(", ")})',
);
Bus.instance.fire(EmbeddingUpdatedEvent());
}
@override
Future<void> deleteClipIndexes() async {
final db = await MLDataDB.instance.asyncDB;
await db.execute('DELETE FROM $clipTable');
Bus.instance.fire(EmbeddingUpdatedEvent());
}
List<EmbeddingVector> _convertToVectors(List<Map<String, dynamic>> results) {
final List<EmbeddingVector> embeddings = [];
for (final result in results) {
final embedding = _getVectorFromRow(result);
if (embedding.isEmpty) continue;
embeddings.add(embedding);
}
return embeddings;
}
EmbeddingVector _getVectorFromRow(Map<String, dynamic> row) {
final fileID = row[fileIDColumn] as int;
final bytes = row[embeddingColumn] as Uint8List;
final list = Float32List.view(bytes.buffer);
return EmbeddingVector(fileID: fileID, embedding: list);
}
List<Object?> _getRowFromEmbedding(ClipEmbedding embedding) {
return [
embedding.fileID,
Float32List.fromList(embedding.embedding).buffer.asUint8List(),
embedding.version,
];
}
}

View File

@ -10,7 +10,6 @@ import "package:ml_linalg/vector.dart";
import "package:photos/core/cache/lru_map.dart";
import "package:photos/core/event_bus.dart";
import "package:photos/db/files_db.dart";
import "package:photos/db/ml/clip_db.dart";
import "package:photos/db/ml/db.dart";
import 'package:photos/events/embedding_updated_event.dart';
import "package:photos/models/file/file.dart";
@ -68,7 +67,8 @@ class SemanticSearchService {
bool isMagicSearchEnabledAndReady() {
return userRemoteFlagService
.getCachedBoolValue(UserRemoteFlagService.mlEnabled) && _textModelIsLoaded;
.getCachedBoolValue(UserRemoteFlagService.mlEnabled) &&
_textModelIsLoaded;
}
// searchScreenQuery should only be used for the user initiate query on the search screen.
@ -77,8 +77,7 @@ class SemanticSearchService {
if (!isMagicSearchEnabledAndReady()) {
if (flagService.internalUser) {
_logger.info(
"ML global consent: ${userRemoteFlagService
.getCachedBoolValue(UserRemoteFlagService.mlEnabled)}, loaded: $_textModelIsLoaded ",
"ML global consent: ${userRemoteFlagService.getCachedBoolValue(UserRemoteFlagService.mlEnabled)}, loaded: $_textModelIsLoaded ",
);
}
return (query, <EnteFile>[]);
@ -221,12 +220,12 @@ class SemanticSearchService {
embedding: clipResult.embedding,
version: clipMlVersion,
);
await MLDataDB.instance.put(embedding);
await MLDataDB.instance.putClip([embedding]);
}
static Future<void> storeEmptyClipImageResult(EnteFile entefile) async {
final embedding = ClipEmbedding.empty(entefile.uploadedFileID!);
await MLDataDB.instance.put(embedding);
await MLDataDB.instance.putClip([embedding]);
}
Future<List<double>> _getTextEmbedding(String query) async {

View File

@ -1,7 +1,6 @@
import "package:flutter/material.dart";
import "package:logging/logging.dart";
import "package:photos/core/event_bus.dart";
import "package:photos/db/ml/clip_db.dart";
import "package:photos/db/ml/db.dart";
import "package:photos/events/people_changed_event.dart";
import "package:photos/service_locator.dart";

View File

@ -4,7 +4,6 @@ import "dart:math" as math show sqrt, min, max;
import "package:flutter/services.dart" show PlatformException;
import "package:logging/logging.dart";
import "package:photos/db/files_db.dart";
import "package:photos/db/ml/clip_db.dart";
import "package:photos/db/ml/db.dart";
import "package:photos/db/ml/filedata.dart";
import "package:photos/extensions/list.dart";
@ -249,7 +248,7 @@ Stream<List<FileMLInstruction>> fetchEmbeddingsAndInstructions(
}
}
await MLDataDB.instance.bulkInsertFaces(faces);
await MLDataDB.instance.putMany(clipEmbeddings);
await MLDataDB.instance.putClip(clipEmbeddings);
}
// Yield any remaining instructions
if (batchToYield.isNotEmpty) {