[mob][photos] Run clip tokenizer in isolate

laurenspriem 2024-07-08 17:48:18 +07:00
parent abd0dedc57
commit 240099df83
3 changed files with 46 additions and 46 deletions
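What the commit does: tokenization moves out of ClipTextEncoder into a self-initialising ClipTextTokenizer singleton, and ClipTextEncoder.infer becomes a static method, so the tokenize-and-encode step can be handed to a worker isolate as one unit. A minimal sketch of the constraint driving that, using Flutter's compute as a stand-in for the app's _computer worker pool (the Encoder and embed names below are illustrative only, not part of the commit):

    import "package:flutter/foundation.dart";

    class Encoder {
      // Static entry point: it runs on the worker isolate and touches no
      // instance state, so everything it needs arrives through args.
      static Future<List<double>> infer(Map args) async {
        final String text = args["text"] as String;
        return List<double>.filled(text.length, 0.0); // stand-in for the model run
      }
    }

    Future<List<double>> embed(String query) {
      // A static tear-off carries no receiver, so only the args map has to be
      // copied to the worker; an instance tear-off would drag the whole
      // encoder object (and its native session handles) along with it.
      return compute(Encoder.infer, {"text": query});
    }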

View File

@@ -1,4 +1,3 @@
import "dart:io";
import "dart:math";
import "package:flutter/foundation.dart";
@@ -6,18 +5,14 @@ import "package:logging/logging.dart";
import "package:onnxruntime/onnxruntime.dart";
import "package:photos/services/machine_learning/ml_model.dart";
import 'package:photos/services/machine_learning/semantic_search/clip/clip_text_tokenizer.dart';
import "package:photos/services/remote_assets_service.dart";
class ClipTextEncoder extends MlModel {
static const kRemoteBucketModelPath = "clip-text-vit-32-float32-int32.onnx";
// static const kRemoteBucketModelPath = "clip-text-vit-32-uint8.onnx";
static const kRemoteBucketVocabPath = "bpe_simple_vocab_16e6.txt";
@override
String get modelRemotePath => kModelBucketEndpoint + kRemoteBucketModelPath;
String get kVocabRemotePath => kModelBucketEndpoint + kRemoteBucketVocabPath;
@override
Logger get logger => _logger;
static final _logger = Logger('ClipTextEncoder');
@@ -30,20 +25,11 @@ class ClipTextEncoder extends MlModel {
static final instance = ClipTextEncoder._privateConstructor();
factory ClipTextEncoder() => instance;
final OnnxTextTokenizer _tokenizer = OnnxTextTokenizer();
Future<void> initTokenizer() async {
final File vocabFile =
await RemoteAssetsService.instance.getAsset(kVocabRemotePath);
final String vocab = await vocabFile.readAsString();
await _tokenizer.init(vocab);
}
Future<List<double>> infer(Map args) async {
static Future<List<double>> infer(Map args) async {
final text = args["text"];
final address = args["address"] as int;
final runOptions = OrtRunOptions();
final tokenize = _tokenizer.tokenize(text);
final List<int> tokenize = await ClipTextTokenizer.instance.tokenize(text);
final data = List.filled(1, Int32List.fromList(tokenize));
final inputOrt = OrtValueTensor.createTensorWithDataList(data, [1, 77]);
final inputs = {'input': inputOrt};
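The [1, 77] shape at the end of this hunk matches the tokenizer's defaults: tokenize() with nText = 76 and pad = true returns nText + 1 = 77 ids per query (start token, at most 75 BPE tokens, end token, then zero padding), and the encoder wraps that single row as an int32 batch of one. A hedged sketch of the hand-off, reusing the names from the diff (the buildClipTextInput helper and the query string are hypothetical):

    import "dart:typed_data";
    import "package:onnxruntime/onnxruntime.dart";
    import "package:photos/services/machine_learning/semantic_search/clip/clip_text_tokenizer.dart";

    Future<OrtValueTensor> buildClipTextInput(String query) async {
      // 77 = nText (76) + 1: <|startoftext|>, up to 75 tokens, <|endoftext|>, padding.
      final List<int> tokens = await ClipTextTokenizer.instance.tokenize(query);
      final data = List.filled(1, Int32List.fromList(tokens)); // batch of one row
      return OrtValueTensor.createTensorWithDataList(data, [1, 77]);
    }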

View File

@@ -2,9 +2,13 @@ import "dart:convert";
import "dart:math";
import "package:html_unescape/html_unescape.dart";
import "package:photos/services/remote_assets_service.dart";
import "package:tuple/tuple.dart";
class OnnxTextTokenizer {
class ClipTextTokenizer {
static const String kVocabRemotePath =
"https://models.ente.io/bpe_simple_vocab_16e6.txt";
late String vocabulary;
late Map<int, String> byteEncoder;
late Map<String, int> byteDecoder;
@@ -26,12 +30,35 @@ class OnnxTextTokenizer {
late int sot;
late int eot;
OnnxTextTokenizer();
bool _isInitialized = false;
// Async method since the loadFile returns a Future and dart constructor cannot be async
Future<void> init(String vocabulary) async {
// Singleton pattern
ClipTextTokenizer._privateConstructor();
static final instance = ClipTextTokenizer._privateConstructor();
factory ClipTextTokenizer() => instance;
Future<List<int>> tokenize(
String text, {
int nText = 76,
bool pad = true,
}) async {
await _init();
var tokens = _encode(text);
tokens = [sot] + tokens.sublist(0, min(nText - 1, tokens.length)) + [eot];
if (pad) {
return tokens + List.filled(nText + 1 - tokens.length, 0);
} else {
return tokens;
}
}
Future<void> _init() async {
if (_isInitialized) return;
final vocabFile =
await RemoteAssetsService.instance.getAsset(kVocabRemotePath);
final String vocabulary = await vocabFile.readAsString();
this.vocabulary = vocabulary;
byteEncoder = bytesToUnicode();
byteEncoder = _bytesToUnicode();
byteDecoder = byteEncoder.map((k, v) => MapEntry(v, k));
var split = vocabulary.split('\n');
@@ -58,28 +85,29 @@ class OnnxTextTokenizer {
sot = encoder['<|startoftext|>']!;
eot = encoder['<|endoftext|>']!;
_isInitialized = true;
}
List<int> encode(String text) {
List<int> _encode(String text) {
final List<int> bpeTokens = [];
text = whitespaceClean(basicClean(text)).toLowerCase();
text = _whitespaceClean(_basicClean(text)).toLowerCase();
for (Match match in pat.allMatches(text)) {
String token = match[0]!;
token = utf8.encode(token).map((b) => byteEncoder[b]).join();
bpe(token)
_bpe(token)
.split(' ')
.forEach((bpeToken) => bpeTokens.add(encoder[bpeToken]!));
}
return bpeTokens;
}
String bpe(String token) {
String _bpe(String token) {
if (cache.containsKey(token)) {
return cache[token]!;
}
var word = token.split('').map((char) => char).toList();
word[word.length - 1] = '${word.last}</w>';
var pairs = getPairs(word);
var pairs = _getPairs(word);
if (pairs.isEmpty) {
return '$token</w>';
}
@@ -123,7 +151,7 @@ class OnnxTextTokenizer {
if (word.length == 1) {
break;
} else {
pairs = getPairs(word);
pairs = _getPairs(word);
}
}
final wordStr = word.join(' ');
@@ -131,21 +159,7 @@ class OnnxTextTokenizer {
return wordStr;
}
List<int> tokenize(String text, {int nText = 76, bool pad = true}) {
var tokens = encode(text);
tokens = [sot] + tokens.sublist(0, min(nText - 1, tokens.length)) + [eot];
if (pad) {
return tokens + List.filled(nText + 1 - tokens.length, 0);
} else {
return tokens;
}
}
List<int> pad(List<int> x, int padLength) {
return x + List.filled(padLength - x.length, 0);
}
Map<int, String> bytesToUnicode() {
Map<int, String> _bytesToUnicode() {
final List<int> bs = [];
for (int i = '!'.codeUnitAt(0); i <= '~'.codeUnitAt(0); i++) {
bs.add(i);
@@ -171,7 +185,7 @@ class OnnxTextTokenizer {
return Map.fromIterables(bs, ds);
}
Set<Tuple2<String, String>> getPairs(List<String> word) {
Set<Tuple2<String, String>> _getPairs(List<String> word) {
final Set<Tuple2<String, String>> pairs = {};
String prevChar = word[0];
for (var i = 1; i < word.length; i++) {
@@ -181,13 +195,13 @@ class OnnxTextTokenizer {
return pairs;
}
String basicClean(String text) {
String _basicClean(String text) {
final unescape = HtmlUnescape();
text = unescape.convert(unescape.convert(text));
return text.trim();
}
String whitespaceClean(String text) {
String _whitespaceClean(String text) {
text = text.replaceAll(RegExp(r'\s+'), ' ');
return text.trim();
}
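With this rewrite the tokenizer owns its initialisation: every public tokenize() call awaits _init(), which fetches and parses the BPE vocabulary only once behind the _isInitialized flag, so the encoder no longer needs a separate initTokenizer() step. A generic sketch of that lazy-init singleton shape (the names below are placeholders, not the real class):

    class LazyTokenizer {
      LazyTokenizer._privateConstructor();
      static final instance = LazyTokenizer._privateConstructor();

      bool _isInitialized = false;
      late Map<String, int> _vocab;

      Future<void> _init() async {
        if (_isInitialized) return; // only the first caller pays the loading cost
        _vocab = {"hello": 0, "world": 1}; // stand-in for the downloaded vocab file
        _isInitialized = true;
      }

      Future<List<int>> tokenize(String text) async {
        await _init(); // every public entry point guards on initialisation
        return text.split(" ").map((w) => _vocab[w] ?? 0).toList();
      }
    }

As a general note on the pattern (not specific to this code): a plain boolean guard lets two concurrent first calls both run the loading body; caching and awaiting the first call's Future is the usual fix when that matters.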

View File

@@ -297,7 +297,7 @@ class SemanticSearchService {
try {
final int clipAddress = ClipTextEncoder.instance.sessionAddress;
final textEmbedding = await _computer.compute(
ClipTextEncoder.instance.infer,
ClipTextEncoder.infer,
param: {
"text": query,
"address": clipAddress,