[mob] Use custom plugin for clip image encoding

parent 519d7a9a5e
commit 08f846c315
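In short, this commit routes CLIP image encoding on Android through the custom onnx_dart platform plugin, while other platforms keep the existing FFI-based onnxruntime path.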
@@ -1,6 +1,6 @@
 import "dart:async";
 import "dart:developer" as dev show log;
-import "dart:io" show File;
+import "dart:io" show File, Platform;
 import "dart:isolate";
 import "dart:math" show min;
 import "dart:typed_data" show Uint8List, ByteData;
@@ -838,6 +838,7 @@ class MLService {
         image,
         imageByteData,
         clipImageAddress,
+        useEntePlugin: Platform.isAndroid,
       );
       result.clip = clipResult;
     }
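The call site above gates the new path by platform. A minimal sketch of the dispatch, assuming the ClipImageEncoder.predict signature introduced in this commit (the import path for ClipImageEncoder is not shown in this diff):

import "dart:io" show Platform;
import "dart:typed_data" show ByteData;
import "dart:ui" show Image;

// Hypothetical helper mirroring the dispatch added above: Android goes
// through the onnx_dart platform plugin, everything else keeps the
// existing FFI-based onnxruntime session.
Future<List<double>> runClipImageEncoding(
  Image image,
  ByteData imageByteData,
  int clipImageAddress,
) {
  return ClipImageEncoder.predict(
    image,
    imageByteData,
    clipImageAddress,
    useEntePlugin: Platform.isAndroid,
  );
}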
@@ -1,14 +1,17 @@
-import "dart:typed_data" show ByteData;
+import "dart:typed_data";
 import "dart:ui" show Image;
 
 import "package:logging/logging.dart";
+import "package:onnx_dart/onnx_dart.dart";
 import "package:onnxruntime/onnxruntime.dart";
 import "package:photos/extensions/stop_watch.dart";
 import "package:photos/services/machine_learning/ml_model.dart";
 import "package:photos/utils/image_ml_util.dart";
 import "package:photos/utils/ml_util.dart";
 
 class ClipImageEncoder extends MlModel {
   static const kRemoteBucketModelPath = "clip-image-vit-32-float32.onnx";
+  static const _modelName = "ClipImageEncoder";
 
   @override
   String get modelRemotePath => kModelBucketEndpoint + kRemoteBucketModelPath;
@@ -18,7 +21,7 @@ class ClipImageEncoder extends MlModel {
   static final _logger = Logger('ClipImageEncoder');
 
   @override
-  String get modelName => "ClipImageEncoder";
+  String get modelName => _modelName;
 
   // Singleton pattern
   ClipImageEncoder._privateConstructor();
@@ -28,10 +31,27 @@ class ClipImageEncoder extends MlModel {
   static Future<List<double>> predict(
     Image image,
     ByteData imageByteData,
-    int sessionAddress,
-  ) async {
+    int sessionAddress, {
+    bool useEntePlugin = false,
+  }) async {
     final w = EnteWatch("ClipImageEncoder.predict")..start();
     final inputList = await preprocessImageClip(image, imageByteData);
     w.log("preprocessImageClip");
+    if (useEntePlugin) {
+      final result = await _runEntePlugin(inputList);
+      w.stopWithLog("done");
+      return result;
+    }
+    final result = _runFFIBasedPredict(inputList, sessionAddress);
+    w.stopWithLog("done");
+    return result;
+  }
+
+  static List<double> _runFFIBasedPredict(
+    Float32List inputList,
+    int sessionAddress,
+  ) {
+    final w = EnteWatch("ClipImageEncoder._runFFIBasedPredict")..start();
     final inputOrt =
         OrtValueTensor.createTensorWithDataList(inputList, [1, 3, 224, 224]);
     final inputs = {'input': inputOrt};
@@ -39,9 +59,23 @@ class ClipImageEncoder extends MlModel {
     final runOptions = OrtRunOptions();
     final outputs = session.run(runOptions, inputs);
     final embedding = (outputs[0]?.value as List<List<double>>)[0];
 
     normalizeEmbedding(embedding);
     w.stopWithLog("done");
     return embedding;
   }
+
+  static Future<List<double>> _runEntePlugin(
+    Float32List inputImageList,
+  ) async {
+    final w = EnteWatch("ClipImageEncoder._runEntePlugin")..start();
+    final OnnxDart plugin = OnnxDart();
+    final result = await plugin.predict(
+      inputImageList,
+      _modelName,
+    );
+    final List<double> embedding = result!.sublist(0, 512);
+    normalizeEmbedding(embedding);
+    w.stopWithLog("done");
+    return embedding;
+  }
 }
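Both paths finish by normalizing the 512-dimensional embedding in place. normalizeEmbedding itself comes from package:photos/utils/ml_util.dart and is not shown in this diff; a sketch of the L2 normalization such a helper plausibly performs:

import "dart:math" show sqrt;

// Sketch of an in-place L2 normalization like normalizeEmbedding
// (the actual implementation may differ).
void l2NormalizeInPlace(List<double> embedding) {
  var sumOfSquares = 0.0;
  for (final v in embedding) {
    sumOfSquares += v * v;
  }
  final norm = sqrt(sumOfSquares);
  if (norm == 0) return; // avoid division by zero for a degenerate vector
  for (var i = 0; i < embedding.length; i++) {
    embedding[i] /= norm;
  }
}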
@@ -334,10 +334,15 @@ class SemanticSearchService {
     int enteFileID,
     Image image,
     ByteData imageByteData,
-    int clipImageAddress,
-  ) async {
-    final embedding =
-        await ClipImageEncoder.predict(image, imageByteData, clipImageAddress);
+    int clipImageAddress, {
+    bool useEntePlugin = false,
+  }) async {
+    final embedding = await ClipImageEncoder.predict(
+      image,
+      imageByteData,
+      clipImageAddress,
+      useEntePlugin: useEntePlugin,
+    );
     final clipResult = ClipResult(fileID: enteFileID, embedding: embedding);
 
     return clipResult;
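Because the stored embeddings are L2-normalized, comparing a query embedding against these ClipResult embeddings reduces cosine similarity to a plain dot product. A hypothetical helper, not part of this diff:

// Cosine similarity of two already-normalized embeddings is their
// dot product; hypothetical illustration only.
double dotProduct(List<double> a, List<double> b) {
  assert(a.length == b.length);
  var sum = 0.0;
  for (var i = 0; i < a.length; i++) {
    sum += a[i] * b[i];
  }
  return sum;
}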
@@ -168,6 +168,11 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
             inputTensorShape[1] = 112
             inputTensorShape[2] = 112
             inputTensorShape[3] = 3
+        } else if(modelType == ModelType.ClipImageEncoder) {
+            inputTensorShape[0] = 1
+            inputTensorShape[1] = 3
+            inputTensorShape[2] = 224
+            inputTensorShape[3] = 224
         }
         val inputTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputData), inputTensorShape)
         val inputs = mutableMapOf<String, OnnxTensor>()
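Note the layout difference between the two branches: the face model's shape above is NHWC (112×112×3) while CLIP expects NCHW (1×3×224×224). A hedged Dart sketch of packing RGBA pixels into an NCHW Float32List; the actual preprocessing is preprocessImageClip, which this diff does not show and which may normalize differently (e.g. CLIP's mean/std normalization):

import "dart:typed_data" show Float32List, Uint8List;

// Packs size x size RGBA bytes into a [1, 3, size, size] NCHW float
// tensor scaled to [0, 1]. Illustrative only; assumes rgba has at
// least 4 * size * size bytes.
Float32List rgbaToNchw(Uint8List rgba, {int size = 224}) {
  final plane = size * size;
  final out = Float32List(3 * plane);
  for (var i = 0; i < plane; i++) {
    out[i] = rgba[i * 4] / 255.0; // R plane
    out[plane + i] = rgba[i * 4 + 1] / 255.0; // G plane
    out[2 * plane + i] = rgba[i * 4 + 2] / 255.0; // B plane
  }
  return out;
}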
@@ -178,14 +183,14 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
         }
         val outputs = session.run(inputs)
         Log.d(TAG, "Output shape: ${outputs.size()}")
-        if (modelType == ModelType.MobileFaceNet) {
-            val outputTensor = (outputs[0].value as Array<FloatArray>)
+        if (modelType == ModelType.YOLOv5Face) {
+            val outputTensor = (outputs[0].value as Array<Array<FloatArray>>).get(0)
             val flatList = outputTensor.flattenToFloatArray()
             withContext(Dispatchers.Main) {
                 result.success(flatList)
             }
         } else {
-            val outputTensor = (outputs[0].value as Array<Array<FloatArray>>).get(0)
+            val outputTensor = (outputs[0].value as Array<FloatArray>)
             val flatList = outputTensor.flattenToFloatArray()
             withContext(Dispatchers.Main) {
                 result.success(flatList)
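On the Dart side, the Float32List handed to plugin.predict must match the [1, 3, 224, 224] shape configured above, i.e. 150528 floats. A hypothetical guard one could run before calling the plugin (not part of this diff):

import "dart:typed_data" show Float32List;

// 1 * 3 * 224 * 224 floats, matching the CLIP input tensor shape the
// Kotlin plugin configures for ModelType.ClipImageEncoder.
const int kClipInputLength = 1 * 3 * 224 * 224; // 150528

void checkClipInput(Float32List inputList) {
  if (inputList.length != kClipInputLength) {
    throw ArgumentError(
      "Expected $kClipInputLength floats for CLIP input, "
      "got ${inputList.length}",
    );
  }
}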