mirror of
https://github.com/ente-io/ente.git
synced 2025-08-10 00:12:04 +00:00
[mob][photos] Rename and delete a lot of CLIP stuff
This commit is contained in:
parent
4cdbb0c128
commit
2d0cadc8c9
@ -5,7 +5,7 @@ import "package:flutter/foundation.dart";
|
|||||||
import "package:logging/logging.dart";
|
import "package:logging/logging.dart";
|
||||||
import "package:onnxruntime/onnxruntime.dart";
|
import "package:onnxruntime/onnxruntime.dart";
|
||||||
import "package:photos/services/machine_learning/ml_model.dart";
|
import "package:photos/services/machine_learning/ml_model.dart";
|
||||||
import 'package:photos/services/machine_learning/semantic_search/frameworks/onnx/onnx_text_tokenizer.dart';
|
import 'package:photos/services/machine_learning/semantic_search/clip/clip_text_tokenizer.dart';
|
||||||
import "package:photos/services/remote_assets_service.dart";
|
import "package:photos/services/remote_assets_service.dart";
|
||||||
|
|
||||||
class ClipTextEncoder extends MlModel {
|
class ClipTextEncoder extends MlModel {
|
@ -1,156 +0,0 @@
|
|||||||
import "dart:async";
|
|
||||||
import "dart:io";
|
|
||||||
|
|
||||||
import "package:connectivity_plus/connectivity_plus.dart";
|
|
||||||
import "package:logging/logging.dart";
|
|
||||||
import "package:photos/core/errors.dart";
|
|
||||||
import "package:photos/core/event_bus.dart";
|
|
||||||
import "package:photos/events/event.dart";
|
|
||||||
import "package:photos/services/remote_assets_service.dart";
|
|
||||||
|
|
||||||
/// Base class for an ML framework that serves an image model and a text
/// model used to produce embeddings.
///
/// Subclasses provide the remote model locations and the load/infer
/// implementations. This class fetches the models through
/// [RemoteAssetsService], retries failed fetches, tracks the
/// [InitializationState], and re-attempts [init] when connectivity changes
/// while it is waiting for a network on which downloads are permitted.
abstract class MLFramework {
  /// Whether the image model is fetched and loaded during [init].
  static const kImageEncoderEnabled = true;

  /// Maximum number of attempts made to fetch a single model.
  static const kMaximumRetrials = 3;

  static final _logger = Logger("MLFramework");

  /// Whether models may be downloaded while connected to a mobile network.
  final bool shouldDownloadOverMobileData;

  /// Completed the first time [init] finishes loading the models.
  final _initializationCompleter = Completer<void>();

  InitializationState _state = InitializationState.notInitialized;

  /// Creates the framework and starts listening for connectivity changes so
  /// that an initialization parked in
  /// [InitializationState.waitingForNetwork] is retried once downloading
  /// becomes possible.
  ///
  /// NOTE(review): the connectivity subscription is never cancelled here;
  /// confirm whether instances are expected to live for the whole app
  /// session.
  MLFramework(this.shouldDownloadOverMobileData) {
    Connectivity()
        .onConnectivityChanged
        .listen((List<ConnectivityResult> result) async {
      _logger.info("Connectivity changed to $result");
      if (_state == InitializationState.waitingForNetwork &&
          await _canDownload()) {
        unawaited(init());
      }
    });
  }

  /// The current initialization state of the framework.
  InitializationState get initializationState => _state;

  /// Broadcasts [state] on the event bus, logs it, and then stores it.
  set _initState(InitializationState state) {
    Bus.instance.fire(MLFrameworkInitializationUpdateEvent(state));
    _logger.info("Init state is $state");
    _state = state;
  }

  /// Returns the path of the Image Model hosted remotely
  String getImageModelRemotePath();

  /// Returns the path of the Text Model hosted remotely
  String getTextModelRemotePath();

  /// Loads the Image Model stored at [path] into the framework
  Future<void> loadImageModel(String path);

  /// Loads the Text Model stored at [path] into the framework
  Future<void> loadTextModel(String path);

  /// Returns the Image Embedding for a file stored at [imagePath]
  Future<List<double>> getImageEmbedding(String imagePath);

  /// Returns the Text Embedding for [text]
  Future<List<double>> getTextEmbedding(String text);

  /// Downloads the models from remote, caches them and loads them into the
  /// framework. Override this method if you would like to control the
  /// initialization. For eg. if you wish to load the model from `/assets`
  /// instead of a CDN.
  ///
  /// If downloading is currently not permitted ([WiFiUnavailableError]), the
  /// returned future completes only after a later [init] call succeeds. Any
  /// other error is logged and rethrown.
  Future<void> init() async {
    try {
      _initState = InitializationState.initializing;
      await Future.wait([_initImageModel(), _initTextModel()]);
    } catch (e, s) {
      _logger.warning(e, s);
      if (e is WiFiUnavailableError) {
        // _getModel has already moved the state to waitingForNetwork; hand
        // back the shared completer so the caller resumes once a retry
        // triggered by the connectivity listener succeeds.
        return _initializationCompleter.future;
      }
      rethrow;
    }
    _initState = InitializationState.initialized;
    // init() may legitimately succeed more than once (explicit call plus a
    // connectivity-triggered retry); completing a Completer twice throws a
    // StateError, so only complete it the first time.
    if (!_initializationCompleter.isCompleted) {
      _initializationCompleter.complete();
    }
  }

  // Releases any resources held by the framework
  Future<void> release() async {}

  /// Returns the dot product of [imageEmbedding] and [textEmbedding].
  ///
  /// This equals the cosine similarity only when both embeddings are already
  /// L2-normalized; no normalization is performed here. The equal-length
  /// check is an `assert` and is therefore only enforced in debug builds.
  double computeScore(List<double> imageEmbedding, List<double> textEmbedding) {
    assert(
      imageEmbedding.length == textEmbedding.length,
      "The two embeddings should have the same length",
    );
    double score = 0;
    for (int index = 0; index < imageEmbedding.length; index++) {
      score += imageEmbedding[index] * textEmbedding[index];
    }
    return score;
  }

  // ---
  // Private methods
  // ---

  /// Fetches the image model and loads it, unless the image encoder is
  /// disabled via [kImageEncoderEnabled].
  Future<void> _initImageModel() async {
    if (!kImageEncoderEnabled) {
      return;
    }
    final imageModel = await _getModel(getImageModelRemotePath());
    await loadImageModel(imageModel.path);
  }

  /// Fetches the text model and loads it.
  Future<void> _initTextModel() async {
    final textModel = await _getModel(getTextModelRemotePath());
    await loadTextModel(textModel.path);
  }

  /// Returns the model file for [url], fetching it if required.
  ///
  /// When [RemoteAssetsService] reports that it already has the asset, it is
  /// returned without any network check. Otherwise the fetch is attempted up
  /// to [kMaximumRetrials] times in total ([trialCount] is the 1-based
  /// attempt number). Throws [WiFiUnavailableError], after moving the state
  /// to [InitializationState.waitingForNetwork], when downloading is not
  /// permitted on the current connection.
  Future<File> _getModel(
    String url, {
    int trialCount = 1,
  }) async {
    if (await RemoteAssetsService.instance.hasAsset(url)) {
      return RemoteAssetsService.instance.getAsset(url);
    }
    if (!await _canDownload()) {
      _initState = InitializationState.waitingForNetwork;
      throw WiFiUnavailableError();
    }
    try {
      // The await is required: returning the future directly would let an
      // asynchronous failure escape this try block, so the retry below would
      // never run.
      return await RemoteAssetsService.instance.getAsset(url);
    } catch (e, s) {
      _logger.severe(e, s);
      if (trialCount < kMaximumRetrials) {
        return _getModel(url, trialCount: trialCount + 1);
      } else {
        rethrow;
      }
    }
  }

  /// Whether a model download is currently permitted.
  ///
  /// Returns `true` unless a mobile connection is reported and
  /// [shouldDownloadOverMobileData] is `false`. Only the mobile case is
  /// inspected, so this also returns `true` when no mobile connection is
  /// reported at all.
  Future<bool> _canDownload() async {
    final List<ConnectivityResult> connections =
        await (Connectivity().checkConnectivity());
    final bool isConnectedToMobile =
        connections.contains(ConnectivityResult.mobile);
    return !isConnectedToMobile || shouldDownloadOverMobileData;
  }
}
|
|
||||||
|
|
||||||
/// Event carrying the [InitializationState] an ML framework is switching to.
class MLFrameworkInitializationUpdateEvent extends Event {
  /// Wraps [state] so it can be delivered to event listeners.
  MLFrameworkInitializationUpdateEvent(this.state);

  /// The initialization state being announced.
  final InitializationState state;
}
|
|
||||||
|
|
||||||
/// Lifecycle stages reported while an ML framework loads its models.
enum InitializationState {
  /// Starting state, before any initialization has been attempted.
  notInitialized,

  /// A model download is needed but is not permitted on the current
  /// connection.
  waitingForNetwork,

  /// Models are being fetched and loaded.
  initializing,

  /// All models have been loaded.
  initialized,
}
|
|
@ -1,127 +0,0 @@
|
|||||||
import "package:computer/computer.dart";
|
|
||||||
import "package:logging/logging.dart";
|
|
||||||
import "package:onnxruntime/onnxruntime.dart";
|
|
||||||
import 'package:photos/services/machine_learning/semantic_search/frameworks/ml_framework.dart';
|
|
||||||
import 'package:photos/services/machine_learning/semantic_search/frameworks/onnx/onnx_image_encoder.dart';
|
|
||||||
import 'package:photos/services/machine_learning/semantic_search/frameworks/onnx/onnx_text_encoder.dart';
|
|
||||||
import "package:photos/utils/image_isolate.dart";
|
|
||||||
|
|
||||||
/// [MLFramework] implementation backed by ONNX Runtime sessions for the CLIP
/// image and text encoders.
///
/// Models are identified by [kModelBucketEndpoint] plus the model file name.
/// Loading a model yields an integer session address that is kept and passed
/// along for later inference and release.
class ONNX extends MLFramework {
  static const kModelBucketEndpoint = "https://models.ente.io/";
  static const kImageModel = "clip-image-vit-32-float32.onnx";
  // static const kTextModel = "clip-text-vit-32-uint8.onnx"; // TODO: check later whether to revert back or not
  static const kTextModel = "clip-text-vit-32-float32-int32.onnx";

  final _computer = Computer.shared();
  final _logger = Logger("ONNX");
  final _clipImage = OnnxImageEncoder();
  final _clipText = OnnxTextEncoder();

  // Session addresses returned by the encoders' loadModel; 0 is the initial
  // value and is treated below as "no session loaded".
  int _textEncoderAddress = 0;
  int _imageEncoderAddress = 0;

  ONNX(super.shouldDownloadOverMobileData);

  /// Returns the remote URL of the image model.
  @override
  String getImageModelRemotePath() {
    return kModelBucketEndpoint + kImageModel;
  }

  /// Returns the remote URL of the text model.
  @override
  String getTextModelRemotePath() {
    return kModelBucketEndpoint + kTextModel;
  }

  /// Initializes the ONNX Runtime environment through the shared [Computer]
  /// and then runs the base-class initialization.
  @override
  Future<void> init() async {
    await _computer.compute(initOrtEnv);
    await super.init();
  }

  /// Loads the image model at [path] and remembers its session address.
  @override
  Future<void> loadImageModel(String path) async {
    // Stopwatch is monotonic, unlike subtracting wall-clock timestamps.
    final stopwatch = Stopwatch()..start();
    _imageEncoderAddress = await _computer.compute(
      _clipImage.loadModel,
      param: {
        "imageModelPath": path,
      },
    );
    _logger.info(
      "Loading image model took: ${stopwatch.elapsedMilliseconds}ms",
    );
  }

  /// Initializes the tokenizer, loads the text model at [path] and remembers
  /// its session address.
  @override
  Future<void> loadTextModel(String path) async {
    _logger.info('loadTextModel called');
    final stopwatch = Stopwatch()..start();
    await _clipText.initTokenizer();
    _textEncoderAddress = await _computer.compute(
      _clipText.loadModel,
      param: {
        "textModelPath": path,
      },
    );
    _logger.info(
      "Loading text model took: ${stopwatch.elapsedMilliseconds}ms",
    );
  }

  /// Returns the image embedding for the file at [imagePath].
  ///
  /// Inference is delegated to [ImageIsolate] using the stored image session
  /// address. Failures are logged and rethrown.
  @override
  Future<List<double>> getImageEmbedding(String imagePath) async {
    _logger.info('getImageEmbedding called');
    try {
      final stopwatch = Stopwatch()..start();
      // TODO: properly integrate with other ml later (FaceMlService)
      final result = await ImageIsolate.instance.inferClipImageEmbedding(
        imagePath,
        _imageEncoderAddress,
      );
      _logger.info(
        "getImageEmbedding done in ${stopwatch.elapsedMilliseconds}ms",
      );
      return result;
    } catch (e, s) {
      _logger.severe(e, s);
      rethrow;
    }
  }

  /// Returns the text embedding for [text].
  ///
  /// Inference is dispatched through the shared [Computer] using the stored
  /// text session address. Failures are logged and rethrown.
  @override
  Future<List<double>> getTextEmbedding(String text) async {
    try {
      final stopwatch = Stopwatch()..start();
      final result = await _computer.compute(
        _clipText.infer,
        param: {
          "text": text,
          "address": _textEncoderAddress,
        },
        taskName: "createTextEmbedding",
      ) as List<double>;
      _logger.info(
        "createTextEmbedding took: ${stopwatch.elapsedMilliseconds}ms",
      );
      return result;
    } catch (e, s) {
      _logger.severe(e, s);
      rethrow;
    }
  }

  /// Releases both encoder sessions (when loaded) and the ONNX Runtime
  /// environment.
  @override
  Future<void> release() async {
    // Skip sessions that were never loaded, and reset the address afterwards
    // so a repeated release() does not touch an already-freed session.
    if (_textEncoderAddress != 0) {
      OrtSession.fromAddress(_textEncoderAddress).release();
      _textEncoderAddress = 0;
    }
    // The image session is released as well; previously it was leaked.
    if (_imageEncoderAddress != 0) {
      OrtSession.fromAddress(_imageEncoderAddress).release();
      _imageEncoderAddress = 0;
    }
    OrtEnv.instance.release();
    _logger.info('Released');
  }
}
|
|
||||||
|
|
||||||
/// Initializes the ONNX Runtime environment singleton.
///
/// Kept as a top-level function so it can be handed to `Computer.compute`.
/// The body performs no asynchronous work, so it is a plain synchronous
/// function rather than a `void ... async` one.
void initOrtEnv() {
  OrtEnv.instance.init();
}
|
|
Loading…
x
Reference in New Issue
Block a user