diff --git a/mobile/lib/services/machine_learning/semantic_search/frameworks/onnx/onnx_image_encoder.dart b/mobile/lib/services/machine_learning/semantic_search/clip/clip_image_encoder.dart similarity index 100% rename from mobile/lib/services/machine_learning/semantic_search/frameworks/onnx/onnx_image_encoder.dart rename to mobile/lib/services/machine_learning/semantic_search/clip/clip_image_encoder.dart diff --git a/mobile/lib/services/machine_learning/semantic_search/frameworks/onnx/onnx_text_encoder.dart b/mobile/lib/services/machine_learning/semantic_search/clip/clip_text_encoder.dart similarity index 98% rename from mobile/lib/services/machine_learning/semantic_search/frameworks/onnx/onnx_text_encoder.dart rename to mobile/lib/services/machine_learning/semantic_search/clip/clip_text_encoder.dart index 6c35f49450..e6db45ff3c 100644 --- a/mobile/lib/services/machine_learning/semantic_search/frameworks/onnx/onnx_text_encoder.dart +++ b/mobile/lib/services/machine_learning/semantic_search/clip/clip_text_encoder.dart @@ -5,7 +5,7 @@ import "package:flutter/foundation.dart"; import "package:logging/logging.dart"; import "package:onnxruntime/onnxruntime.dart"; import "package:photos/services/machine_learning/ml_model.dart"; -import 'package:photos/services/machine_learning/semantic_search/frameworks/onnx/onnx_text_tokenizer.dart'; +import 'package:photos/services/machine_learning/semantic_search/clip/clip_text_tokenizer.dart'; import "package:photos/services/remote_assets_service.dart"; class ClipTextEncoder extends MlModel { diff --git a/mobile/lib/services/machine_learning/semantic_search/frameworks/onnx/onnx_text_tokenizer.dart b/mobile/lib/services/machine_learning/semantic_search/clip/clip_text_tokenizer.dart similarity index 100% rename from mobile/lib/services/machine_learning/semantic_search/frameworks/onnx/onnx_text_tokenizer.dart rename to mobile/lib/services/machine_learning/semantic_search/clip/clip_text_tokenizer.dart diff --git a/mobile/lib/services/machine_learning/semantic_search/frameworks/ml_framework.dart b/mobile/lib/services/machine_learning/semantic_search/frameworks/ml_framework.dart deleted file mode 100644 index d3736d7680..0000000000 --- a/mobile/lib/services/machine_learning/semantic_search/frameworks/ml_framework.dart +++ /dev/null @@ -1,156 +0,0 @@ -import "dart:async"; -import "dart:io"; - -import "package:connectivity_plus/connectivity_plus.dart"; -import "package:logging/logging.dart"; -import "package:photos/core/errors.dart"; -import "package:photos/core/event_bus.dart"; -import "package:photos/events/event.dart"; -import "package:photos/services/remote_assets_service.dart"; - -abstract class MLFramework { - static const kImageEncoderEnabled = true; - static const kMaximumRetrials = 3; - - static final _logger = Logger("MLFramework"); - - final bool shouldDownloadOverMobileData; - final _initializationCompleter = Completer(); - - InitializationState _state = InitializationState.notInitialized; - - MLFramework(this.shouldDownloadOverMobileData) { - Connectivity() - .onConnectivityChanged - .listen((List result) async { - _logger.info("Connectivity changed to $result"); - if (_state == InitializationState.waitingForNetwork && - await _canDownload()) { - unawaited(init()); - } - }); - } - - InitializationState get initializationState => _state; - - set _initState(InitializationState state) { - Bus.instance.fire(MLFrameworkInitializationUpdateEvent(state)); - _logger.info("Init state is $state"); - _state = state; - } - - /// Returns the path of the Image Model hosted remotely - String getImageModelRemotePath(); - - /// Returns the path of the Text Model hosted remotely - String getTextModelRemotePath(); - - /// Loads the Image Model stored at [path] into the framework - Future loadImageModel(String path); - - /// Loads the Text Model stored at [path] into the framework - Future loadTextModel(String path); - - /// Returns the Image Embedding for a file stored at [imagePath] - Future> getImageEmbedding(String imagePath); - - /// Returns the Text Embedding for [text] - Future> getTextEmbedding(String text); - - /// Downloads the models from remote, caches them and loads them into the - /// framework. Override this method if you would like to control the - /// initialization. For eg. if you wish to load the model from `/assets` - /// instead of a CDN. - Future init() async { - try { - _initState = InitializationState.initializing; - await Future.wait([_initImageModel(), _initTextModel()]); - } catch (e, s) { - _logger.warning(e, s); - if (e is WiFiUnavailableError) { - return _initializationCompleter.future; - } else { - rethrow; - } - } - _initState = InitializationState.initialized; - _initializationCompleter.complete(); - } - - // Releases any resources held by the framework - Future release() async {} - - /// Returns the cosine similarity between [imageEmbedding] and [textEmbedding] - double computeScore(List imageEmbedding, List textEmbedding) { - assert( - imageEmbedding.length == textEmbedding.length, - "The two embeddings should have the same length", - ); - double score = 0; - for (int index = 0; index < imageEmbedding.length; index++) { - score += imageEmbedding[index] * textEmbedding[index]; - } - return score; - } - - // --- - // Private methods - // --- - - Future _initImageModel() async { - if (!kImageEncoderEnabled) { - return; - } - final imageModel = await _getModel(getImageModelRemotePath()); - await loadImageModel(imageModel.path); - } - - Future _initTextModel() async { - final textModel = await _getModel(getTextModelRemotePath()); - await loadTextModel(textModel.path); - } - - Future _getModel( - String url, { - int trialCount = 1, - }) async { - if (await RemoteAssetsService.instance.hasAsset(url)) { - return RemoteAssetsService.instance.getAsset(url); - } - if (!await _canDownload()) { - _initState = InitializationState.waitingForNetwork; - throw WiFiUnavailableError(); - } - try { - return RemoteAssetsService.instance.getAsset(url); - } catch (e, s) { - _logger.severe(e, s); - if (trialCount < kMaximumRetrials) { - return _getModel(url, trialCount: trialCount + 1); - } else { - rethrow; - } - } - } - - Future _canDownload() async { - final List connections = - await (Connectivity().checkConnectivity()); - final bool isConnectedToMobile = - connections.contains(ConnectivityResult.mobile); - return !isConnectedToMobile || shouldDownloadOverMobileData; - } -} - -class MLFrameworkInitializationUpdateEvent extends Event { - final InitializationState state; - - MLFrameworkInitializationUpdateEvent(this.state); -} - -enum InitializationState { - notInitialized, - waitingForNetwork, - initializing, - initialized, -} diff --git a/mobile/lib/services/machine_learning/semantic_search/frameworks/onnx/onnx.dart b/mobile/lib/services/machine_learning/semantic_search/frameworks/onnx/onnx.dart deleted file mode 100644 index 44d6cabc49..0000000000 --- a/mobile/lib/services/machine_learning/semantic_search/frameworks/onnx/onnx.dart +++ /dev/null @@ -1,127 +0,0 @@ -import "package:computer/computer.dart"; -import "package:logging/logging.dart"; -import "package:onnxruntime/onnxruntime.dart"; -import 'package:photos/services/machine_learning/semantic_search/frameworks/ml_framework.dart'; -import 'package:photos/services/machine_learning/semantic_search/frameworks/onnx/onnx_image_encoder.dart'; -import 'package:photos/services/machine_learning/semantic_search/frameworks/onnx/onnx_text_encoder.dart'; -import "package:photos/utils/image_isolate.dart"; - -class ONNX extends MLFramework { - static const kModelBucketEndpoint = "https://models.ente.io/"; - static const kImageModel = "clip-image-vit-32-float32.onnx"; - // static const kTextModel = "clip-text-vit-32-uint8.onnx"; // TODO: check later whether to revert back or not - static const kTextModel = "clip-text-vit-32-float32-int32.onnx"; - - final _computer = Computer.shared(); - final _logger = Logger("ONNX"); - final _clipImage = OnnxImageEncoder(); - final _clipText = OnnxTextEncoder(); - int _textEncoderAddress = 0; - int _imageEncoderAddress = 0; - - ONNX(super.shouldDownloadOverMobileData); - - @override - String getImageModelRemotePath() { - return kModelBucketEndpoint + kImageModel; - } - - @override - String getTextModelRemotePath() { - return kModelBucketEndpoint + kTextModel; - } - - @override - Future init() async { - await _computer.compute(initOrtEnv); - await super.init(); - } - - @override - Future loadImageModel(String path) async { - final startTime = DateTime.now(); - _imageEncoderAddress = await _computer.compute( - _clipImage.loadModel, - param: { - "imageModelPath": path, - }, - ); - final endTime = DateTime.now(); - _logger.info( - "Loading image model took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch).toString()}ms", - ); - } - - @override - Future loadTextModel(String path) async { - _logger.info('loadTextModel called'); - final startTime = DateTime.now(); - await _clipText.initTokenizer(); - _textEncoderAddress = await _computer.compute( - _clipText.loadModel, - param: { - "textModelPath": path, - }, - ); - final endTime = DateTime.now(); - _logger.info( - "Loading text model took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch).toString()}ms", - ); - } - - @override - Future> getImageEmbedding(String imagePath) async { - _logger.info('getImageEmbedding called'); - try { - final startTime = DateTime.now(); - // TODO: properly integrate with other ml later (FaceMlService) - final result = await ImageIsolate.instance.inferClipImageEmbedding( - imagePath, - _imageEncoderAddress, - ); - final endTime = DateTime.now(); - _logger.info( - "getImageEmbedding done in ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)}ms", - ); - return result; - } catch (e, s) { - _logger.severe(e, s); - rethrow; - } - } - - @override - Future> getTextEmbedding(String text) async { - try { - final startTime = DateTime.now(); - final result = await _computer.compute( - _clipText.infer, - param: { - "text": text, - "address": _textEncoderAddress, - }, - taskName: "createTextEmbedding", - ) as List; - final endTime = DateTime.now(); - _logger.info( - "createTextEmbedding took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)}ms", - ); - return result; - } catch (e, s) { - _logger.severe(e, s); - rethrow; - } - } - - @override - Future release() async { - final session = OrtSession.fromAddress(_textEncoderAddress); - session.release(); - OrtEnv.instance.release(); - _logger.info('Released'); - } -} - -void initOrtEnv() async { - OrtEnv.instance.init(); -}