[mob] Use custom plugin for clip image encoding

parent 519d7a9a5e
commit 08f846c315
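On Android, CLIP image encoding now runs through the custom onnx_dart platform plugin instead of the FFI-based onnxruntime bindings; all other platforms keep the FFI path. A useEntePlugin flag is threaded from MLService through SemanticSearchService into ClipImageEncoder, and the Kotlin plugin gains the CLIP input tensor shape plus reworked output handling.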
@@ -1,6 +1,6 @@
 import "dart:async";
 import "dart:developer" as dev show log;
-import "dart:io" show File;
+import "dart:io" show File, Platform;
 import "dart:isolate";
 import "dart:math" show min;
 import "dart:typed_data" show Uint8List, ByteData;
@@ -838,6 +838,7 @@ class MLService {
         image,
         imageByteData,
         clipImageAddress,
+        useEntePlugin: Platform.isAndroid,
       );
       result.clip = clipResult;
     }
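The plugin backend is implemented only for Android (see the OnnxDartPlugin hunks below), hence the Platform.isAndroid gate at this call site. A hedged sketch of the new call shape; encodeClip is a hypothetical wrapper name, and the three arguments are assumed to come from the surrounding MLService pipeline:

import "dart:io" show Platform;
import "dart:typed_data" show ByteData;
import "dart:ui" show Image;
// The ClipImageEncoder import is repo-internal and not shown in this diff.

/// Hypothetical wrapper, not part of the commit: it only illustrates how the
/// new named parameter selects the backend per platform.
Future<List<double>> encodeClip(
  Image image,
  ByteData imageByteData,
  int clipImageAddress,
) {
  // Only Android ships the onnx_dart plugin backend; every other platform
  // keeps using the FFI-based onnxruntime session at clipImageAddress.
  return ClipImageEncoder.predict(
    image,
    imageByteData,
    clipImageAddress,
    useEntePlugin: Platform.isAndroid,
  );
}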
@@ -1,14 +1,17 @@
-import "dart:typed_data" show ByteData;
+import "dart:typed_data";
 import "dart:ui" show Image;

 import "package:logging/logging.dart";
+import "package:onnx_dart/onnx_dart.dart";
 import "package:onnxruntime/onnxruntime.dart";
+import "package:photos/extensions/stop_watch.dart";
 import "package:photos/services/machine_learning/ml_model.dart";
 import "package:photos/utils/image_ml_util.dart";
 import "package:photos/utils/ml_util.dart";

 class ClipImageEncoder extends MlModel {
   static const kRemoteBucketModelPath = "clip-image-vit-32-float32.onnx";
+  static const _modelName = "ClipImageEncoder";

   @override
   String get modelRemotePath => kModelBucketEndpoint + kRemoteBucketModelPath;
@@ -18,7 +21,7 @@ class ClipImageEncoder {
   static final _logger = Logger('ClipImageEncoder');

   @override
-  String get modelName => "ClipImageEncoder";
+  String get modelName => _modelName;

   // Singleton pattern
   ClipImageEncoder._privateConstructor();
@@ -28,10 +31,27 @@ class ClipImageEncoder {
   static Future<List<double>> predict(
     Image image,
     ByteData imageByteData,
-    int sessionAddress,
-  ) async {
+    int sessionAddress, {
+    bool useEntePlugin = false,
+  }) async {
+    final w = EnteWatch("ClipImageEncoder.predict")..start();
     final inputList = await preprocessImageClip(image, imageByteData);
+    w.log("preprocessImageClip");
+    if (useEntePlugin) {
+      final result = await _runEntePlugin(inputList);
+      w.stopWithLog("done");
+      return result;
+    }
+    final result = _runFFIBasedPredict(inputList, sessionAddress);
+    w.stopWithLog("done");
+    return result;
+  }
+
+  static List<double> _runFFIBasedPredict(
+    Float32List inputList,
+    int sessionAddress,
+  ) {
+    final w = EnteWatch("ClipImageEncoder._runFFIBasedPredict")..start();
     final inputOrt =
         OrtValueTensor.createTensorWithDataList(inputList, [1, 3, 224, 224]);
     final inputs = {'input': inputOrt};
@@ -39,9 +59,23 @@ class ClipImageEncoder {
     final runOptions = OrtRunOptions();
     final outputs = session.run(runOptions, inputs);
     final embedding = (outputs[0]?.value as List<List<double>>)[0];

     normalizeEmbedding(embedding);
+    w.stopWithLog("done");
+    return embedding;
+  }
+
+  static Future<List<double>> _runEntePlugin(
+    Float32List inputImageList,
+  ) async {
+    final w = EnteWatch("ClipImageEncoder._runEntePlugin")..start();
+    final OnnxDart plugin = OnnxDart();
+    final result = await plugin.predict(
+      inputImageList,
+      _modelName,
+    );
+    final List<double> embedding = result!.sublist(0, 512);
+    normalizeEmbedding(embedding);
+    w.stopWithLog("done");
     return embedding;
   }
 }
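Both inference paths finish with normalizeEmbedding from package:photos/utils/ml_util.dart, whose body is not part of this diff. CLIP ViT-B/32 produces a 512-dimensional image embedding, which is why the plugin result is trimmed with result!.sublist(0, 512). A minimal sketch, assuming the helper performs the usual in-place L2 normalization:

import "dart:math" show sqrt;

// Sketch of an in-place L2 normalization such as normalizeEmbedding
// presumably performs; unit-length vectors let cosine similarity be
// computed as a plain dot product during semantic search.
void normalizeEmbedding(List<double> embedding) {
  var sumOfSquares = 0.0;
  for (final v in embedding) {
    sumOfSquares += v * v;
  }
  final norm = sqrt(sumOfSquares);
  if (norm == 0) return; // leave an all-zero vector untouched
  for (var i = 0; i < embedding.length; i++) {
    embedding[i] /= norm;
  }
}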
@@ -334,10 +334,15 @@ class SemanticSearchService {
     int enteFileID,
     Image image,
     ByteData imageByteData,
-    int clipImageAddress,
-  ) async {
-    final embedding =
-        await ClipImageEncoder.predict(image, imageByteData, clipImageAddress);
+    int clipImageAddress, {
+    bool useEntePlugin = false,
+  }) async {
+    final embedding = await ClipImageEncoder.predict(
+      image,
+      imageByteData,
+      clipImageAddress,
+      useEntePlugin: useEntePlugin,
+    );
     final clipResult = ClipResult(fileID: enteFileID, embedding: embedding);

     return clipResult;
@@ -168,6 +168,11 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
             inputTensorShape[1] = 112
             inputTensorShape[2] = 112
             inputTensorShape[3] = 3
+        } else if (modelType == ModelType.ClipImageEncoder) {
+            inputTensorShape[0] = 1
+            inputTensorShape[1] = 3
+            inputTensorShape[2] = 224
+            inputTensorShape[3] = 224
         }
         val inputTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputData), inputTensorShape)
         val inputs = mutableMapOf<String, OnnxTensor>()
@@ -178,14 +183,14 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
         }
         val outputs = session.run(inputs)
         Log.d(TAG, "Output shape: ${outputs.size()}")
-        if (modelType == ModelType.MobileFaceNet) {
-            val outputTensor = (outputs[0].value as Array<FloatArray>)
+        if (modelType == ModelType.YOLOv5Face) {
+            val outputTensor = (outputs[0].value as Array<Array<FloatArray>>).get(0)
             val flatList = outputTensor.flattenToFloatArray()
             withContext(Dispatchers.Main) {
                 result.success(flatList)
             }
         } else {
-            val outputTensor = (outputs[0].value as Array<Array<FloatArray>>).get(0)
+            val outputTensor = (outputs[0].value as Array<FloatArray>)
             val flatList = outputTensor.flattenToFloatArray()
             withContext(Dispatchers.Main) {
                 result.success(flatList)
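The new branch sizes the CLIP input tensor as [1, 3, 224, 224] (NCHW), matching what the Dart FFI path passes to OrtValueTensor.createTensorWithDataList, while MobileFaceNet keeps its NHWC [1, 112, 112, 3] shape. The output branches are also reorganized so that only YOLOv5Face's three-dimensional detection output takes the Array<Array<FloatArray>> cast; flat embedding outputs such as MobileFaceNet and CLIP now take the two-dimensional Array<FloatArray> cast. A hedged sanity check on the Dart side, assuming preprocessImageClip emits exactly one float per tensor element:

import "dart:typed_data" show Float32List;

// The Kotlin side wraps the raw floats in a [1, 3, 224, 224] tensor, so the
// Float32List sent over the platform channel must hold exactly
// 1 * 3 * 224 * 224 = 150528 values. Debug-mode check only (assert).
void assertClipInputShape(Float32List input) {
  const expected = 1 * 3 * 224 * 224;
  assert(
    input.length == expected,
    "CLIP image input must contain $expected floats, got ${input.length}",
  );
}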