import "dart:io";
import "dart:typed_data";

import "package:photos/core/event_bus.dart";
import "package:photos/db/ml/db.dart";
import "package:photos/db/ml/db_fields.dart";
import "package:photos/events/embedding_updated_event.dart";
import "package:photos/models/ml/clip.dart";
import "package:photos/models/ml/ml_versions.dart";
import "package:photos/models/ml/vector.dart";

/// CLIP-embedding read/write helpers on top of [MLDataDB].
///
/// Embeddings are stored in `clipTable` as raw Float32 bytes in
/// `embeddingColumn`, keyed by `fileIDColumn` and tagged with the ML
/// version in `mlVersionColumn`. Every mutating call fires an
/// [EmbeddingUpdatedEvent] on the global [Bus] once the write completes.
extension ClipDB on MLDataDB {
  // NOTE(review): not referenced inside this extension; confirm it is still
  // used by callers before removing.
  static const databaseName = "ente.embeddings.db";

  /// Returns every stored CLIP embedding.
  ///
  /// Rows whose decoded embedding reports `isEmpty` are skipped.
  Future<List<ClipEmbedding>> getAllClipEmbeddings() async {
    final db = await MLDataDB.instance.asyncDB;
    final results = await db.getAll('SELECT * FROM $clipTable');
    return _convertToEmbeddings(results);
  }

  /// Returns every stored CLIP embedding as a lightweight [EmbeddingVector]
  /// (file ID + vector, without the ML version).
  ///
  /// Rows whose decoded vector reports `isEmpty` are skipped.
  Future<List<EmbeddingVector>> getAllClipVectors() async {
    final db = await MLDataDB.instance.asyncDB;
    final results = await db.getAll('SELECT * FROM $clipTable');
    return _convertToVectors(results);
  }

  /// Returns a map from each indexed file ID to the ML version it was
  /// indexed with.
  Future<Map<int, int>> clipIndexedFileWithVersion() async {
    final db = await MLDataDB.instance.asyncDB;
    final maps = await db
        .getAll('SELECT $fileIDColumn , $mlVersionColumn FROM $clipTable');
    final Map<int, int> result = {};
    for (final map in maps) {
      result[map[fileIDColumn] as int] = map[mlVersionColumn] as int;
    }
    return result;
  }

  /// Counts distinct file IDs whose stored embedding was produced with an ML
  /// version of at least [minimumMlVersion] (defaults to [clipMlVersion]).
  Future<int> getClipIndexedFileCount({
    int minimumMlVersion = clipMlVersion,
  }) async {
    final db = await MLDataDB.instance.asyncDB;
    // The version is bound as a parameter rather than interpolated into the
    // SQL text.
    const String query =
        'SELECT COUNT(DISTINCT $fileIDColumn) as count FROM $clipTable WHERE $mlVersionColumn >= ?';
    final List<Map<String, dynamic>> maps =
        await db.getAll(query, [minimumMlVersion]);
    return maps.first['count'] as int;
  }

  /// Inserts [embedding], replacing any conflicting existing row, and fires
  /// an [EmbeddingUpdatedEvent].
  Future<void> put(ClipEmbedding embedding) async {
    final db = await MLDataDB.instance.asyncDB;
    await db.execute(
      'INSERT OR REPLACE INTO $clipTable ($fileIDColumn, $embeddingColumn, $mlVersionColumn) VALUES (?, ?, ?)',
      _getRowFromEmbedding(embedding),
    );
    Bus.instance.fire(EmbeddingUpdatedEvent());
  }

  /// Inserts or replaces all [embeddings] in a single batch, then fires one
  /// [EmbeddingUpdatedEvent].
  ///
  /// Does nothing (and fires no event) when [embeddings] is empty.
  Future<void> putMany(List<ClipEmbedding> embeddings) async {
    if (embeddings.isEmpty) return;
    final db = await MLDataDB.instance.asyncDB;
    final inputs = embeddings.map(_getRowFromEmbedding).toList();
    await db.executeBatch(
      'INSERT OR REPLACE INTO $clipTable ($fileIDColumn, $embeddingColumn, $mlVersionColumn) values(?, ?, ?)',
      inputs,
    );
    Bus.instance.fire(EmbeddingUpdatedEvent());
  }

  /// Deletes the embeddings of every file in [fileIDs], then fires an
  /// [EmbeddingUpdatedEvent].
  ///
  /// Does nothing (and fires no event) when [fileIDs] is empty, mirroring
  /// [putMany].
  Future<void> deleteClipEmbeddings(List<int> fileIDs) async {
    if (fileIDs.isEmpty) return;
    final db = await MLDataDB.instance.asyncDB;
    // IDs are ints, so joining them into the SQL text cannot inject SQL; it
    // also avoids SQLite's bound-parameter limit for very large deletes.
    await db.execute(
      'DELETE FROM $clipTable WHERE $fileIDColumn IN (${fileIDs.join(", ")})',
    );
    Bus.instance.fire(EmbeddingUpdatedEvent());
  }

  /// Deletes every stored CLIP embedding and fires an
  /// [EmbeddingUpdatedEvent].
  Future<void> deleteClipIndexes() async {
    final db = await MLDataDB.instance.asyncDB;
    await db.execute('DELETE FROM $clipTable');
    Bus.instance.fire(EmbeddingUpdatedEvent());
  }

  /// Decodes [results] into [ClipEmbedding]s, dropping empty ones.
  List<ClipEmbedding> _convertToEmbeddings(
    List<Map<String, dynamic>> results,
  ) {
    final List<ClipEmbedding> embeddings = [];
    for (final result in results) {
      final embedding = _getEmbeddingFromRow(result);
      if (embedding.isEmpty) continue;
      embeddings.add(embedding);
    }
    return embeddings;
  }

  /// Decodes [results] into [EmbeddingVector]s, dropping empty ones.
  List<EmbeddingVector> _convertToVectors(
    List<Map<String, dynamic>> results,
  ) {
    final List<EmbeddingVector> embeddings = [];
    for (final result in results) {
      final embedding = _getVectorFromRow(result);
      if (embedding.isEmpty) continue;
      embeddings.add(embedding);
    }
    return embeddings;
  }

  /// Builds a [ClipEmbedding] from one `clipTable` row.
  ClipEmbedding _getEmbeddingFromRow(Map<String, dynamic> row) {
    final fileID = row[fileIDColumn] as int;
    final bytes = row[embeddingColumn] as Uint8List;
    final version = row[mlVersionColumn] as int;
    final list = _float32ListFromBytes(bytes);
    return ClipEmbedding(fileID: fileID, embedding: list, version: version);
  }

  /// Builds an [EmbeddingVector] from one `clipTable` row.
  EmbeddingVector _getVectorFromRow(Map<String, dynamic> row) {
    final fileID = row[fileIDColumn] as int;
    final bytes = row[embeddingColumn] as Uint8List;
    final list = _float32ListFromBytes(bytes);
    return EmbeddingVector(fileID: fileID, embedding: list);
  }

  /// Reinterprets the bytes of [bytes] as Float32 values.
  ///
  /// Only the region covered by [bytes] is used: `bytes.buffer` may be larger
  /// than the blob when [bytes] is itself a view, so the view honours
  /// `offsetInBytes` and `lengthInBytes`. Trailing bytes that do not fill a
  /// whole float are ignored.
  static Float32List _float32ListFromBytes(Uint8List bytes) {
    final length = bytes.lengthInBytes ~/ Float32List.bytesPerElement;
    if (bytes.offsetInBytes % Float32List.bytesPerElement == 0) {
      return bytes.buffer.asFloat32List(bytes.offsetInBytes, length);
    }
    // A Float32List view must start on a 4-byte boundary; copy unaligned
    // blobs into a fresh (aligned) buffer first.
    return Float32List.view(Uint8List.fromList(bytes).buffer, 0, length);
  }

  /// Serialises [embedding] into the positional values expected by the
  /// insert statements: `[fileID, embedding bytes, version]`.
  List<Object?> _getRowFromEmbedding(ClipEmbedding embedding) {
    return [
      embedding.fileID,
      Float32List.fromList(embedding.embedding).buffer.asUint8List(),
      embedding.version,
    ];
  }

  /// Deletes the deprecated `object-box-store` directory and `default.isar`
  /// file under [dir], if they exist.
  Future<void> _clearDeprecatedStores(Directory dir) async {
    final deprecatedObjectBox = Directory(dir.path + "/object-box-store");
    if (await deprecatedObjectBox.exists()) {
      await deprecatedObjectBox.delete(recursive: true);
    }
    final deprecatedIsar = File(dir.path + "/default.isar");
    if (await deprecatedIsar.exists()) {
      await deprecatedIsar.delete();
    }
  }
}