[mob] Use custom plugin for clip image encoding

Neeraj Gupta 2024-07-10 18:22:04 +05:30
parent 519d7a9a5e
commit 08f846c315
4 changed files with 58 additions and 13 deletions

View File

@@ -1,6 +1,6 @@
 import "dart:async";
 import "dart:developer" as dev show log;
-import "dart:io" show File;
+import "dart:io" show File, Platform;
 import "dart:isolate";
 import "dart:math" show min;
 import "dart:typed_data" show Uint8List, ByteData;
@@ -838,6 +838,7 @@ class MLService {
         image,
         imageByteData,
         clipImageAddress,
+        useEntePlugin: Platform.isAndroid,
       );
       result.clip = clipResult;
     }

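For reference, the effect of the new argument is that Android routes CLIP image inference through the onnx_dart platform plugin, while other platforms keep the existing FFI session. A minimal sketch of the gating pattern (the wrapper name and project imports are illustrative, not from this commit):

import "dart:io" show Platform;
import "dart:typed_data" show ByteData;
import "dart:ui" show Image;

// Hypothetical call site (project imports elided): use the onnx_dart plugin
// only on Android; everywhere else the FFI session at sessionAddress is used.
Future<List<double>> encodeClipImage(
  Image image,
  ByteData imageByteData,
  int sessionAddress,
) {
  return ClipImageEncoder.predict(
    image,
    imageByteData,
    sessionAddress,
    useEntePlugin: Platform.isAndroid, // same gate MLService passes above
  );
}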
View File

@@ -1,14 +1,17 @@
-import "dart:typed_data" show ByteData;
+import "dart:typed_data";
 import "dart:ui" show Image;
 
 import "package:logging/logging.dart";
+import "package:onnx_dart/onnx_dart.dart";
 import "package:onnxruntime/onnxruntime.dart";
 import "package:photos/extensions/stop_watch.dart";
 import "package:photos/services/machine_learning/ml_model.dart";
 import "package:photos/utils/image_ml_util.dart";
+import "package:photos/utils/ml_util.dart";
 
 class ClipImageEncoder extends MlModel {
   static const kRemoteBucketModelPath = "clip-image-vit-32-float32.onnx";
+  static const _modelName = "ClipImageEncoder";
 
   @override
   String get modelRemotePath => kModelBucketEndpoint + kRemoteBucketModelPath;
@@ -18,7 +21,7 @@ class ClipImageEncoder extends MlModel {
   static final _logger = Logger('ClipImageEncoder');
 
   @override
-  String get modelName => "ClipImageEncoder";
+  String get modelName => _modelName;
 
   // Singleton pattern
   ClipImageEncoder._privateConstructor();
@@ -28,10 +31,27 @@ class ClipImageEncoder extends MlModel {
   static Future<List<double>> predict(
     Image image,
     ByteData imageByteData,
-    int sessionAddress,
-  ) async {
+    int sessionAddress, {
+    bool useEntePlugin = false,
+  }) async {
     final w = EnteWatch("ClipImageEncoder.predict")..start();
     final inputList = await preprocessImageClip(image, imageByteData);
     w.log("preprocessImageClip");
+    if (useEntePlugin) {
+      final result = await _runEntePlugin(inputList);
+      w.stopWithLog("done");
+      return result;
+    }
+    final result = _runFFIBasedPredict(inputList, sessionAddress);
+    w.stopWithLog("done");
+    return result;
+  }
+
+  static List<double> _runFFIBasedPredict(
+    Float32List inputList,
+    int sessionAddress,
+  ) {
+    final w = EnteWatch("ClipImageEncoder._runFFIBasedPredict")..start();
     final inputOrt =
         OrtValueTensor.createTensorWithDataList(inputList, [1, 3, 224, 224]);
     final inputs = {'input': inputOrt};
@@ -39,9 +59,23 @@ class ClipImageEncoder extends MlModel {
     final runOptions = OrtRunOptions();
     final outputs = session.run(runOptions, inputs);
     final embedding = (outputs[0]?.value as List<List<double>>)[0];
     normalizeEmbedding(embedding);
+    w.stopWithLog("done");
     return embedding;
   }
+
+  static Future<List<double>> _runEntePlugin(
+    Float32List inputImageList,
+  ) async {
+    final w = EnteWatch("ClipImageEncoder._runEntePlugin")..start();
+    final OnnxDart plugin = OnnxDart();
+    final result = await plugin.predict(
+      inputImageList,
+      _modelName,
+    );
+    final List<double> embedding = result!.sublist(0, 512);
+    normalizeEmbedding(embedding);
+    w.stopWithLog("done");
+    return embedding;
+  }
 }

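Both paths end with normalizeEmbedding, which comes from ml_util.dart and is not shown in this diff. Assuming it performs the usual in-place L2 normalization for CLIP embeddings (so cosine similarity reduces to a dot product), a sketch of what it likely does:

import "dart:math" show sqrt;

// Hypothetical sketch of normalizeEmbedding (the real implementation lives
// in package:photos/utils/ml_util.dart). Scales the vector in place to unit
// L2 norm so downstream similarity can be computed as a plain dot product.
void normalizeEmbedding(List<double> embedding) {
  var sumOfSquares = 0.0;
  for (final v in embedding) {
    sumOfSquares += v * v;
  }
  final norm = sqrt(sumOfSquares);
  if (norm == 0) return; // hypothetical guard against a zero vector
  for (var i = 0; i < embedding.length; i++) {
    embedding[i] /= norm;
  }
}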
View File

@@ -334,10 +334,15 @@ class SemanticSearchService {
     int enteFileID,
     Image image,
     ByteData imageByteData,
-    int clipImageAddress,
-  ) async {
-    final embedding =
-        await ClipImageEncoder.predict(image, imageByteData, clipImageAddress);
+    int clipImageAddress, {
+    bool useEntePlugin = false,
+  }) async {
+    final embedding = await ClipImageEncoder.predict(
+      image,
+      imageByteData,
+      clipImageAddress,
+      useEntePlugin: useEntePlugin,
+    );
     final clipResult = ClipResult(fileID: enteFileID, embedding: embedding);
 
     return clipResult;
View File

@@ -168,6 +168,11 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
             inputTensorShape[1] = 112
             inputTensorShape[2] = 112
             inputTensorShape[3] = 3
+        } else if(modelType == ModelType.ClipImageEncoder) {
+            inputTensorShape[0] = 1
+            inputTensorShape[1] = 3
+            inputTensorShape[2] = 224
+            inputTensorShape[3] = 224
         }
         val inputTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputData), inputTensorShape)
         val inputs = mutableMapOf<String, OnnxTensor>()
@@ -178,14 +183,14 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
         }
         val outputs = session.run(inputs)
         Log.d(TAG, "Output shape: ${outputs.size()}")
-        if (modelType == ModelType.MobileFaceNet) {
-            val outputTensor = (outputs[0].value as Array<FloatArray>)
+        if (modelType == ModelType.YOLOv5Face) {
+            val outputTensor = (outputs[0].value as Array<Array<FloatArray>>).get(0)
             val flatList = outputTensor.flattenToFloatArray()
             withContext(Dispatchers.Main) {
                 result.success(flatList)
             }
         } else {
-            val outputTensor = (outputs[0].value as Array<Array<FloatArray>>).get(0)
+            val outputTensor = (outputs[0].value as Array<FloatArray>)
            val flatList = outputTensor.flattenToFloatArray()
            withContext(Dispatchers.Main) {
                result.success(flatList)
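
On the Dart side, the plugin path served by this Kotlin handler amounts to passing the preprocessed tensor and the model name across the platform channel. A sketch with the shape assumption made explicit (the helper name is hypothetical; OnnxDart.predict is the call used in the diff above):

import "dart:typed_data";
import "package:onnx_dart/onnx_dart.dart";

// Hypothetical helper: the handler above wraps the incoming floats in a
// [1, 3, 224, 224] tensor for ClipImageEncoder, i.e. it expects exactly
// 1 * 3 * 224 * 224 = 150528 floats, and returns the flattened output.
Future<List<double>> clipEmbedViaPlugin(Float32List input) async {
  assert(input.length == 1 * 3 * 224 * 224); // shape set in OnnxDartPlugin.kt
  final result = await OnnxDart().predict(input, "ClipImageEncoder");
  return result!.sublist(0, 512); // CLIP ViT-B/32 embedding size, as above
}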