diff --git a/mobile/lib/main.dart b/mobile/lib/main.dart index 587713952d..db0f52d82b 100644 --- a/mobile/lib/main.dart +++ b/mobile/lib/main.dart @@ -37,6 +37,7 @@ import "package:photos/services/machine_learning/face_ml/person/person_service.d import 'package:photos/services/machine_learning/file_ml/remote_fileml_service.dart'; import "package:photos/services/machine_learning/machine_learning_controller.dart"; import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart'; +import "package:photos/services/magic_cache_service.dart"; import 'package:photos/services/memories_service.dart'; import 'package:photos/services/push_service.dart'; import 'package:photos/services/remote_sync_service.dart'; @@ -303,6 +304,8 @@ Future _init(bool isBackground, {String via = ''}) async { preferences, ); + MagicCacheService.instance.init(preferences); + initComplete = true; _logger.info("Initialization done"); } catch (e, s) { diff --git a/mobile/lib/models/search/search_types.dart b/mobile/lib/models/search/search_types.dart index e6dab467e1..1b5a186522 100644 --- a/mobile/lib/models/search/search_types.dart +++ b/mobile/lib/models/search/search_types.dart @@ -13,7 +13,6 @@ import "package:photos/models/collection/collection_items.dart"; import "package:photos/models/search/search_result.dart"; import "package:photos/models/typedefs.dart"; import "package:photos/services/collections_service.dart"; -import "package:photos/services/machine_learning/semantic_search/frameworks/ml_framework.dart"; import "package:photos/services/search_service.dart"; import "package:photos/ui/viewer/gallery/collection_page.dart"; import "package:photos/ui/viewer/location/add_location_sheet.dart"; @@ -292,8 +291,6 @@ extension SectionTypeExtensions on SectionType { switch (this) { case SectionType.location: return [Bus.instance.on()]; - case SectionType.magic: - return [Bus.instance.on()]; default: return []; } diff --git a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart index bfe8d9cec7..44fc961f3f 100644 --- a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart +++ b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart @@ -267,6 +267,49 @@ class SemanticSearchService { return results; } + Future> getMatchingFileIDs(String query, double minScore) async { + final textEmbedding = await _getTextEmbedding(query); + + final queryResults = + await _getScores(textEmbedding, scoreThreshold: minScore); + + final queryResultIds = []; + for (QueryResult result in queryResults) { + queryResultIds.add(result.id); + } + + final filesMap = await FilesDB.instance.getFilesFromIDs( + queryResultIds, + ); + final results = []; + + final ignoredCollections = + CollectionsService.instance.getHiddenCollectionIds(); + final deletedEntries = []; + for (final result in queryResults) { + final file = filesMap[result.id]; + if (file != null && !ignoredCollections.contains(file.collectionID)) { + results.add(file); + } + if (file == null) { + deletedEntries.add(result.id); + } + } + + _logger.info(results.length.toString() + " results"); + + if (deletedEntries.isNotEmpty) { + unawaited(EmbeddingsDB.instance.deleteEmbeddings(deletedEntries)); + } + + final matchingFileIDs = []; + for (EnteFile file in results) { + matchingFileIDs.add(file.uploadedFileID!); + } + + return matchingFileIDs; + } + void _addToQueue(EnteFile file) { if (!LocalSettings.instance.hasEnabledMagicSearch()) { return; diff --git a/mobile/lib/services/magic_cache_service.dart b/mobile/lib/services/magic_cache_service.dart new file mode 100644 index 0000000000..a379b0388a --- /dev/null +++ b/mobile/lib/services/magic_cache_service.dart @@ -0,0 +1,225 @@ +import "dart:async"; +import "dart:convert"; +import "dart:io"; + +import "package:logging/logging.dart"; +import "package:path_provider/path_provider.dart"; +import "package:photos/models/file/file.dart"; +import "package:photos/models/search/generic_search_result.dart"; +import "package:photos/models/search/search_types.dart"; +import "package:photos/service_locator.dart"; +import "package:photos/services/machine_learning/semantic_search/semantic_search_service.dart"; +import "package:photos/services/remote_assets_service.dart"; +import "package:photos/services/search_service.dart"; +import "package:shared_preferences/shared_preferences.dart"; + +class MagicCache { + final String title; + final List fileUploadedIDs; + MagicCache(this.title, this.fileUploadedIDs); + + factory MagicCache.fromJson(Map json) { + return MagicCache( + json['title'], + List.from(json['fileUploadedIDs']), + ); + } + + Map toJson() { + return { + 'title': title, + 'fileUploadedIDs': fileUploadedIDs, + }; + } + + static String encodeListToJson(List magicCaches) { + final jsonList = magicCaches.map((cache) => cache.toJson()).toList(); + return jsonEncode(jsonList); + } + + static List decodeJsonToList(String jsonString) { + final jsonList = jsonDecode(jsonString) as List; + return jsonList.map((json) => MagicCache.fromJson(json)).toList(); + } +} + +extension MagicCacheServiceExtension on MagicCache { + Future toGenericSearchResult() async { + final allEnteFiles = await SearchService.instance.getAllFiles(); + final enteFilesInMagicCache = []; + for (EnteFile file in allEnteFiles) { + if (file.uploadedFileID != null && + fileUploadedIDs.contains(file.uploadedFileID as int)) { + enteFilesInMagicCache.add(file); + } + } + return GenericSearchResult( + ResultType.magic, + title, + enteFilesInMagicCache, + ); + } +} + +class MagicCacheService { + static const _lastMagicCacheUpdateTime = "last_magic_cache_update_time"; + static const _kMagicPromptsDataUrl = "https://discover.ente.io/v1.json"; + + /// Delay is for cache update to be done not during app init, during which a + /// lot of other things are happening. + static const _kCacheUpdateDelay = Duration(seconds: 10); + + late SharedPreferences _prefs; + final Logger _logger = Logger((MagicCacheService).toString()); + MagicCacheService._privateConstructor(); + + static final MagicCacheService instance = + MagicCacheService._privateConstructor(); + + void init(SharedPreferences preferences) { + _prefs = preferences; + _updateCacheIfTheTimeHasCome(); + } + + Future resetLastMagicCacheUpdateTime() async { + await _prefs.setInt( + _lastMagicCacheUpdateTime, + DateTime.now().millisecondsSinceEpoch, + ); + } + + int get lastMagicCacheUpdateTime { + return _prefs.getInt(_lastMagicCacheUpdateTime) ?? 0; + } + + Future _updateCacheIfTheTimeHasCome() async { + final jsonFile = await RemoteAssetsService.instance + .getAssetIfUpdated(_kMagicPromptsDataUrl); + if (jsonFile != null) { + Future.delayed(_kCacheUpdateDelay, () { + unawaited(_updateCache()); + }); + return; + } + if (lastMagicCacheUpdateTime < + DateTime.now() + .subtract(const Duration(days: 3)) + .millisecondsSinceEpoch) { + Future.delayed(_kCacheUpdateDelay, () { + unawaited(_updateCache()); + }); + } + } + + Future _getCachePath() async { + return (await getApplicationSupportDirectory()).path + "/cache/magic_cache"; + } + + Future> _getMatchingFileIDsForPromptData( + Map promptData, + ) async { + final result = await SemanticSearchService.instance.getMatchingFileIDs( + promptData["prompt"] as String, + promptData["minimumScore"] as double, + ); + + return result; + } + + Future _updateCache() async { + try { + _logger.info("updating magic cache"); + final magicPromptsData = await _loadMagicPrompts(); + final magicCaches = await nonEmptyMagicResults(magicPromptsData); + final file = File(await _getCachePath()); + if (!file.existsSync()) { + file.createSync(recursive: true); + } + file.writeAsBytesSync(MagicCache.encodeListToJson(magicCaches).codeUnits); + unawaited( + resetLastMagicCacheUpdateTime().onError((error, stackTrace) { + _logger.warning( + "Error resetting last magic cache update time", + error, + ); + }), + ); + } catch (e) { + _logger.info("Error updating magic cache", e); + } + } + + Future?> _getMagicCache() async { + final file = File(await _getCachePath()); + if (!file.existsSync()) { + _logger.info("No magic cache found"); + return null; + } + final jsonString = file.readAsStringSync(); + return MagicCache.decodeJsonToList(jsonString); + } + + Future clearMagicCache() async { + File(await _getCachePath()).deleteSync(); + } + + Future> getMagicGenericSearchResult() async { + try { + final magicCaches = await _getMagicCache(); + if (magicCaches == null) { + _logger.info("No magic cache found"); + return []; + } + + final List genericSearchResults = []; + for (MagicCache magicCache in magicCaches) { + final genericSearchResult = await magicCache.toGenericSearchResult(); + genericSearchResults.add(genericSearchResult); + } + return genericSearchResults; + } catch (e) { + _logger.info("Error getting magic generic search result", e); + return []; + } + } + + Future> _loadMagicPrompts() async { + final file = + await RemoteAssetsService.instance.getAsset(_kMagicPromptsDataUrl); + + final json = jsonDecode(await file.readAsString()); + return json["prompts"]; + } + + ///Returns random non-empty magic results from magicPromptsData + ///Length is capped at [limit], can be less than [limit] if there are not enough + ///non-empty results + Future> nonEmptyMagicResults( + List magicPromptsData, + ) async { + //Show all magic prompts to internal users for feedback on results + final limit = flagService.internalUser ? magicPromptsData.length : 4; + final results = []; + final randomIndexes = List.generate( + magicPromptsData.length, + (index) => index, + growable: false, + )..shuffle(); + for (final index in randomIndexes) { + final files = + await _getMatchingFileIDsForPromptData(magicPromptsData[index]); + if (files.isNotEmpty) { + results.add( + MagicCache( + magicPromptsData[index]["title"] as String, + files, + ), + ); + } + if (results.length >= limit) { + break; + } + } + return results; + } +} diff --git a/mobile/lib/services/remote_assets_service.dart b/mobile/lib/services/remote_assets_service.dart index 487f2f8b11..b9bed09b50 100644 --- a/mobile/lib/services/remote_assets_service.dart +++ b/mobile/lib/services/remote_assets_service.dart @@ -21,17 +21,46 @@ class RemoteAssetsService { Future getAsset(String remotePath, {bool refetch = false}) async { final path = await _getLocalPath(remotePath); final file = File(path); - if (await file.exists() && !refetch) { + if (file.existsSync() && !refetch) { _logger.info("Returning cached file for $remotePath"); return file; } else { final tempFile = File(path + ".temp"); await _downloadFile(remotePath, tempFile.path); - await tempFile.rename(path); + tempFile.renameSync(path); return File(path); } } + ///Returns asset if the remote asset is new compared to the local copy of it + Future getAssetIfUpdated(String remotePath) async { + try { + final path = await _getLocalPath(remotePath); + final file = File(path); + if (!file.existsSync()) { + final tempFile = File(path + ".temp"); + await _downloadFile(remotePath, tempFile.path); + tempFile.renameSync(path); + return File(path); + } else { + final existingFileSize = File(path).lengthSync(); + final tempFile = File(path + ".temp"); + await _downloadFile(remotePath, tempFile.path); + final newFileSize = tempFile.lengthSync(); + if (existingFileSize != newFileSize) { + tempFile.renameSync(path); + return File(path); + } else { + tempFile.deleteSync(); + return null; + } + } + } catch (e) { + _logger.warning("Error getting asset if updated", e); + return null; + } + } + Future hasAsset(String remotePath) async { final path = await _getLocalPath(remotePath); return File(path).exists(); @@ -60,8 +89,8 @@ class RemoteAssetsService { Future _downloadFile(String url, String savePath) async { _logger.info("Downloading " + url); final existingFile = File(savePath); - if (await existingFile.exists()) { - await existingFile.delete(); + if (existingFile.existsSync()) { + existingFile.deleteSync(); } await NetworkClient.instance.getDio().download( diff --git a/mobile/lib/services/search_service.dart b/mobile/lib/services/search_service.dart index e55684ed9b..2161e654b7 100644 --- a/mobile/lib/services/search_service.dart +++ b/mobile/lib/services/search_service.dart @@ -1,4 +1,3 @@ -import "dart:convert"; import "dart:math"; import "package:flutter/cupertino.dart"; @@ -33,13 +32,14 @@ import "package:photos/services/location_service.dart"; import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart"; import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart'; -import "package:photos/services/remote_assets_service.dart"; +import "package:photos/services/magic_cache_service.dart"; import "package:photos/states/location_screen_state.dart"; import "package:photos/ui/viewer/location/add_location_sheet.dart"; import "package:photos/ui/viewer/location/location_screen.dart"; import "package:photos/ui/viewer/people/cluster_page.dart"; import "package:photos/ui/viewer/people/people_page.dart"; import 'package:photos/utils/date_time_util.dart'; +import "package:photos/utils/local_settings.dart"; import "package:photos/utils/navigation_util.dart"; import 'package:tuple/tuple.dart'; @@ -49,9 +49,6 @@ class SearchService { final _logger = Logger((SearchService).toString()); final _collectionService = CollectionsService.instance; static const _maximumResultsLimit = 20; - static const _kMagicPromptsDataUrl = "https://discover.ente.io/v1.json"; - - var magicPromptsData = []; SearchService._privateConstructor(); @@ -63,17 +60,6 @@ class SearchService { _cachedFilesFuture = null; _cachedHiddenFilesFuture = null; }); - if (flagService.internalUser) { - _loadMagicPrompts(); - } - } - - Future _loadMagicPrompts() async { - final file = await RemoteAssetsService.instance - .getAsset(_kMagicPromptsDataUrl, refetch: true); - - final json = jsonDecode(await file.readAsString()); - magicPromptsData = json["prompts"]; } Set ignoreCollections() { @@ -192,26 +178,12 @@ class SearchService { } Future> getMagicSectionResutls() async { - if (!SemanticSearchService.instance.isMagicSearchEnabledAndReady()) { + if (LocalSettings.instance.hasEnabledMagicSearch() && + flagService.internalUser) { + return MagicCacheService.instance.getMagicGenericSearchResult(); + } else { return []; } - final searchResuts = []; - for (Map magicPrompt in magicPromptsData) { - final files = await SemanticSearchService.instance.getMatchingFiles( - magicPrompt["prompt"], - scoreThreshold: magicPrompt["minimumScore"], - ); - if (files.isNotEmpty) { - searchResuts.add( - GenericSearchResult( - ResultType.magic, - magicPrompt["title"], - files, - ), - ); - } - } - return searchResuts; } Future> getRandomMomentsSearchResults( diff --git a/mobile/lib/ui/viewer/search_tab/magic_section.dart b/mobile/lib/ui/viewer/search_tab/magic_section.dart index d088de92e5..b5ec65b02a 100644 --- a/mobile/lib/ui/viewer/search_tab/magic_section.dart +++ b/mobile/lib/ui/viewer/search_tab/magic_section.dart @@ -33,28 +33,6 @@ class _MagicSectionState extends State { super.initState(); _magicSearchResults = widget.magicSearchResults; - //At times, ml framework is not initialized when the search results are - //requested (widget.momentsSearchResults is empty) and is initialized - //(which fires MLFrameworkInitializationUpdateEvent with - //InitializationState.initialized) before initState of this widget is - //called. We do listen to MLFrameworkInitializationUpdateEvent and reload - //this widget but the event with InitializationState.initialized would have - //already been fired in the above case. - if (_magicSearchResults.isEmpty) { - SectionType.magic - .getData( - context, - limit: kSearchSectionLimit, - ) - .then((value) { - if (mounted) { - setState(() { - _magicSearchResults = value as List; - }); - } - }); - } - final streamsToListenTo = SectionType.magic.sectionUpdateEvents(); for (Stream stream in streamsToListenTo) { streamSubscriptions.add( @@ -84,7 +62,6 @@ class _MagicSectionState extends State { @override void didUpdateWidget(covariant MagicSection oldWidget) { super.didUpdateWidget(oldWidget); - //widget.magicSearch is empty when doing a hot reload if (widget.magicSearchResults.isNotEmpty) { _magicSearchResults = widget.magicSearchResults; } @@ -262,7 +239,7 @@ class MagicRecommendation extends StatelessWidget { ), ConstrainedBox( constraints: const BoxConstraints( - maxWidth: 76, + maxWidth: 88, ), child: Padding( padding: const EdgeInsets.only(