[mob][photos] Separate out tokenizer init

laurenspriem 2024-07-25 23:28:17 +02:00
parent 18a5f4d212
commit a5b47f16a9
3 changed files with 47 additions and 32 deletions

View File

@@ -8,6 +8,7 @@ import "package:logging/logging.dart";
 import "package:photos/face/model/box.dart";
 import "package:photos/services/machine_learning/ml_model.dart";
 import "package:photos/services/machine_learning/semantic_search/clip/clip_text_encoder.dart";
+import "package:photos/services/machine_learning/semantic_search/clip/clip_text_tokenizer.dart";
 import "package:photos/services/remote_assets_service.dart";
 import "package:photos/utils/image_ml_util.dart";
 import "package:synchronized/synchronized.dart";
@@ -15,6 +16,7 @@ import "package:synchronized/synchronized.dart";
 enum MLComputerOperation {
   generateFaceThumbnails,
   loadModel,
+  initializeClipTokenizer,
   runClipText,
 }
@@ -23,6 +25,7 @@ class MLComputerIsolate {
   final _initLock = Lock();
   final _functionLock = Lock();
+  final _initModelLock = Lock();
   late ReceivePort _receivePort = ReceivePort();
   late SendPort _mainSendPort;
@@ -96,6 +99,11 @@
           );
           sendPort.send(address);
           break;
+        case MLComputerOperation.initializeClipTokenizer:
+          final vocabPath = args["vocabPath"] as String;
+          await ClipTextTokenizer.instance.init(vocabPath);
+          sendPort.send(true);
+          break;
         case MLComputerOperation.runClipText:
           final textEmbedding = await ClipTextEncoder.predict(args);
           sendPort.send(List.from(textEmbedding, growable: false));
@@ -160,17 +168,12 @@
     try {
       await _ensureLoadedClipTextModel();
       final int clipAddress = ClipTextEncoder.instance.sessionAddress;
-      final String tokenizerRemotePath =
-          ClipTextEncoder.instance.vocabRemotePath;
-      final String tokenizerVocabPath =
-          await RemoteAssetsService.instance.getAssetPath(tokenizerRemotePath);
       final textEmbedding = await _runInIsolate(
         (
           MLComputerOperation.runClipText,
           {
             "text": query,
             "address": clipAddress,
-            "vocabPath": tokenizerVocabPath,
           }
         ),
       ) as List<double>;
@@ -182,25 +185,40 @@
   }
   Future<void> _ensureLoadedClipTextModel() async {
-    if (ClipTextEncoder.instance.isInitialized) return;
-    try {
-      final String modelName = ClipTextEncoder.instance.modelName;
-      final String remotePath = ClipTextEncoder.instance.modelRemotePath;
-      final String modelPath =
-          await RemoteAssetsService.instance.getAssetPath(remotePath);
-      final address = await _runInIsolate(
-        (
-          MLComputerOperation.loadModel,
-          {
-            'modelName': modelName,
-            'modelPath': modelPath,
-          },
-        ),
-      ) as int;
-      ClipTextEncoder.instance.storeSessionAddress(address);
-    } catch (e, s) {
-      _logger.severe("Could not load clip text model in MLComputer", e, s);
-      rethrow;
-    }
+    return _initModelLock.synchronized(() async {
+      if (ClipTextEncoder.instance.isInitialized) return;
+      try {
+        // Initialize ClipText tokenizer
+        final String tokenizerRemotePath =
+            ClipTextEncoder.instance.vocabRemotePath;
+        final String tokenizerVocabPath = await RemoteAssetsService.instance
+            .getAssetPath(tokenizerRemotePath);
+        await _runInIsolate(
+          (
+            MLComputerOperation.initializeClipTokenizer,
+            {'vocabPath': tokenizerVocabPath},
+          ),
+        );
+        // Load ClipText model
+        final String modelName = ClipTextEncoder.instance.modelName;
+        final String modelRemotePath = ClipTextEncoder.instance.modelRemotePath;
+        final String modelPath =
+            await RemoteAssetsService.instance.getAssetPath(modelRemotePath);
+        final address = await _runInIsolate(
+          (
+            MLComputerOperation.loadModel,
+            {
+              'modelName': modelName,
+              'modelPath': modelPath,
+            },
+          ),
+        ) as int;
+        ClipTextEncoder.instance.storeSessionAddress(address);
+      } catch (e, s) {
+        _logger.severe("Could not load clip text model in MLComputer", e, s);
+        rethrow;
+      }
+    });
   }
 }

View File

@@ -33,11 +33,9 @@ class ClipTextEncoder extends MlModel {
   factory ClipTextEncoder() => instance;
   static Future<List<double>> predict(Map args) async {
-    final text = args["text"];
+    final text = args["text"] as String;
     final address = args["address"] as int;
-    final vocabPath = args["vocabPath"] as String;
-    final List<int> tokenize =
-        await ClipTextTokenizer.instance.tokenize(text, vocabPath);
+    final List<int> tokenize = await ClipTextTokenizer.instance.tokenize(text);
     final int32list = Int32List.fromList(tokenize);
     if (MlModel.usePlatformPlugin) {
       return await _runPlatformPlugin(int32list);

View File

@@ -36,15 +36,14 @@ class ClipTextTokenizer {
   static final instance = ClipTextTokenizer._privateConstructor();
   factory ClipTextTokenizer() => instance;
-  Future<List<int>> tokenize(String text, String vocabPath) async {
-    await _init(vocabPath);
+  Future<List<int>> tokenize(String text) async {
     var tokens = _encode(text);
     tokens =
         [sot] + tokens.sublist(0, min(totalTokens - 2, tokens.length)) + [eot];
     return tokens + List.filled(totalTokens - tokens.length, 0);
  }
-  Future<void> _init(String vocabPath) async {
+  Future<void> init(String vocabPath) async {
     if (_isInitialized) return;
     final vocabFile = File(vocabPath);
     final String vocabulary = await vocabFile.readAsString();
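
After this change, the tokenizer is initialized once through the now-public init() method, and tokenize() no longer takes a vocab path. A minimal sketch of the resulting call order, assuming vocabPath is already a local file path (e.g. resolved earlier via RemoteAssetsService) and omitting error handling; tokenizeQuery is a hypothetical helper, not part of the diff:

Future<List<int>> tokenizeQuery(String query, String vocabPath) async {
  // init() is now called explicitly, once, instead of being invoked
  // lazily inside every tokenize() call (it returns early if already
  // initialized).
  await ClipTextTokenizer.instance.init(vocabPath);
  // tokenize() no longer needs the vocab path.
  return ClipTextTokenizer.instance.tokenize(query);
}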