[mob][photos] Run clip tokenizer in isolate
parent abd0dedc57
commit 240099df83
clip_text_encoder.dart

@ -1,4 +1,3 @@
-import "dart:io";
 import "dart:math";
 
 import "package:flutter/foundation.dart";
@ -6,18 +5,14 @@ import "package:logging/logging.dart";
 import "package:onnxruntime/onnxruntime.dart";
 import "package:photos/services/machine_learning/ml_model.dart";
+import 'package:photos/services/machine_learning/semantic_search/clip/clip_text_tokenizer.dart';
-import "package:photos/services/remote_assets_service.dart";
 
 class ClipTextEncoder extends MlModel {
   static const kRemoteBucketModelPath = "clip-text-vit-32-float32-int32.onnx";
   // static const kRemoteBucketModelPath = "clip-text-vit-32-uint8.onnx";
-  static const kRemoteBucketVocabPath = "bpe_simple_vocab_16e6.txt";
 
   @override
   String get modelRemotePath => kModelBucketEndpoint + kRemoteBucketModelPath;
 
-  String get kVocabRemotePath => kModelBucketEndpoint + kRemoteBucketVocabPath;
-
   @override
   Logger get logger => _logger;
   static final _logger = Logger('ClipTextEncoder');
@ -30,20 +25,11 @@ class ClipTextEncoder extends MlModel {
   static final instance = ClipTextEncoder._privateConstructor();
   factory ClipTextEncoder() => instance;
 
-  final OnnxTextTokenizer _tokenizer = OnnxTextTokenizer();
-
-  Future<void> initTokenizer() async {
-    final File vocabFile =
-        await RemoteAssetsService.instance.getAsset(kVocabRemotePath);
-    final String vocab = await vocabFile.readAsString();
-    await _tokenizer.init(vocab);
-  }
-
-  Future<List<double>> infer(Map args) async {
+  static Future<List<double>> infer(Map args) async {
     final text = args["text"];
     final address = args["address"] as int;
     final runOptions = OrtRunOptions();
-    final tokenize = _tokenizer.tokenize(text);
+    final List<int> tokenize = await ClipTextTokenizer.instance.tokenize(text);
     final data = List.filled(1, Int32List.fromList(tokenize));
     final inputOrt = OrtValueTensor.createTensorWithDataList(data, [1, 77]);
     final inputs = {'input': inputOrt};
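The heart of this change is the last hunk above: `infer` stops being an instance method and becomes static, and tokenization goes through the `ClipTextTokenizer` singleton. Dart isolates share no memory, so a function handed to a worker isolate must be top-level or static; an instance-method tear-off would drag `this` (the ONNX session, the logger) along, and those are not sendable. Only plain data crosses the boundary, which is why the native session travels as a raw integer address. A minimal self-contained sketch of the pattern, using `dart:isolate` directly rather than the `computer` package the app uses (the `Worker` class and its fake embedding are hypothetical):

import "dart:isolate";

class Worker {
  // Static entry point: safe to invoke from another isolate because it
  // captures no instance state; `args` carries only sendable values.
  static Future<List<double>> infer(Map args) async {
    final String text = args["text"] as String;
    final int address = args["address"] as int;
    // The real encoder rehydrates an OrtSession from `address`; here we
    // fabricate an embedding so the sketch runs on its own.
    return List<double>.filled(4, (text.length + address).toDouble());
  }
}

Future<void> main() async {
  // Isolate.run spawns a fresh isolate, runs the closure, and sends back the result.
  final embedding =
      await Isolate.run(() => Worker.infer({"text": "a dog", "address": 0}));
  print(embedding); // [5.0, 5.0, 5.0, 5.0]
}

The `computer` package used at the call site below wraps this same mechanism in a reusable worker pool.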
clip_text_tokenizer.dart

@ -2,9 +2,13 @@ import "dart:convert";
 import "dart:math";
 
 import "package:html_unescape/html_unescape.dart";
+import "package:photos/services/remote_assets_service.dart";
 import "package:tuple/tuple.dart";
 
-class OnnxTextTokenizer {
+class ClipTextTokenizer {
+  static const String kVocabRemotePath =
+      "https://models.ente.io/bpe_simple_vocab_16e6.txt";
+
   late String vocabulary;
   late Map<int, String> byteEncoder;
   late Map<String, int> byteDecoder;
@ -26,12 +30,35 @@ class OnnxTextTokenizer {
   late int sot;
   late int eot;
 
-  OnnxTextTokenizer();
+  bool _isInitialized = false;
 
-  // Async method since the loadFile returns a Future and dart constructor cannot be async
-  Future<void> init(String vocabulary) async {
+  // Singleton pattern
+  ClipTextTokenizer._privateConstructor();
+  static final instance = ClipTextTokenizer._privateConstructor();
+  factory ClipTextTokenizer() => instance;
+
+  Future<List<int>> tokenize(
+    String text, {
+    int nText = 76,
+    bool pad = true,
+  }) async {
+    await _init();
+    var tokens = _encode(text);
+    tokens = [sot] + tokens.sublist(0, min(nText - 1, tokens.length)) + [eot];
+    if (pad) {
+      return tokens + List.filled(nText + 1 - tokens.length, 0);
+    } else {
+      return tokens;
+    }
+  }
+
+  Future<void> _init() async {
+    if (_isInitialized) return;
+    final vocabFile =
+        await RemoteAssetsService.instance.getAsset(kVocabRemotePath);
+    final String vocabulary = await vocabFile.readAsString();
     this.vocabulary = vocabulary;
-    byteEncoder = bytesToUnicode();
+    byteEncoder = _bytesToUnicode();
     byteDecoder = byteEncoder.map((k, v) => MapEntry(v, k));
 
     var split = vocabulary.split('\n');
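Note the shape of the new API: `tokenize` is now async and begins with `await _init()`, so the vocabulary download and BPE table construction happen lazily on first use, from whichever isolate calls it, with `_isInitialized` turning repeat calls into no-ops. A minimal sketch of that guard pattern (all names hypothetical):

class VocabLoader {
  VocabLoader._privateConstructor();
  static final instance = VocabLoader._privateConstructor();

  bool _initialized = false;
  late Map<String, int> _vocab;

  // Every public entry point awaits _init first, so callers never need a
  // separate setup call.
  Future<int?> lookup(String token) async {
    await _init();
    return _vocab[token];
  }

  Future<void> _init() async {
    if (_initialized) return;
    // Stand-in for downloading and parsing the remote vocab file.
    _vocab = {"hello</w>": 1, "world</w>": 2};
    _initialized = true;
  }
}

One caveat the boolean flag leaves open: two overlapping first calls can both enter `_init` while the download is in flight; caching the `Future` from the first call would make initialization run exactly once.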
@ -58,28 +85,29 @@ class OnnxTextTokenizer {
 
     sot = encoder['<|startoftext|>']!;
     eot = encoder['<|endoftext|>']!;
+    _isInitialized = true;
   }
 
-  List<int> encode(String text) {
+  List<int> _encode(String text) {
     final List<int> bpeTokens = [];
-    text = whitespaceClean(basicClean(text)).toLowerCase();
+    text = _whitespaceClean(_basicClean(text)).toLowerCase();
     for (Match match in pat.allMatches(text)) {
       String token = match[0]!;
       token = utf8.encode(token).map((b) => byteEncoder[b]).join();
-      bpe(token)
+      _bpe(token)
           .split(' ')
           .forEach((bpeToken) => bpeTokens.add(encoder[bpeToken]!));
     }
     return bpeTokens;
   }
 
-  String bpe(String token) {
+  String _bpe(String token) {
     if (cache.containsKey(token)) {
       return cache[token]!;
     }
     var word = token.split('').map((char) => char).toList();
     word[word.length - 1] = '${word.last}</w>';
-    var pairs = getPairs(word);
+    var pairs = _getPairs(word);
     if (pairs.isEmpty) {
       return '$token</w>';
     }
@ -123,7 +151,7 @@ class OnnxTextTokenizer {
       if (word.length == 1) {
         break;
       } else {
-        pairs = getPairs(word);
+        pairs = _getPairs(word);
       }
     }
     final wordStr = word.join(' ');
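For orientation, `_getPairs` (renamed in this hunk) is recomputed after every merge because each merge changes which symbols are adjacent. A self-contained sketch of what it produces, mirroring the code in this file:

import "package:tuple/tuple.dart";

Set<Tuple2<String, String>> getPairs(List<String> word) {
  final pairs = <Tuple2<String, String>>{};
  var prevChar = word[0];
  for (var i = 1; i < word.length; i++) {
    pairs.add(Tuple2(prevChar, word[i]));
    prevChar = word[i];
  }
  return pairs;
}

void main() {
  // "low" as symbols, with the end-of-word marker appended to the last one:
  print(getPairs(["l", "o", "w</w>"]));
  // {[l, o], [o, w</w>]}
}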
@ -131,21 +159,7 @@ class OnnxTextTokenizer {
     return wordStr;
   }
 
-  List<int> tokenize(String text, {int nText = 76, bool pad = true}) {
-    var tokens = encode(text);
-    tokens = [sot] + tokens.sublist(0, min(nText - 1, tokens.length)) + [eot];
-    if (pad) {
-      return tokens + List.filled(nText + 1 - tokens.length, 0);
-    } else {
-      return tokens;
-    }
-  }
-
-  List<int> pad(List<int> x, int padLength) {
-    return x + List.filled(padLength - x.length, 0);
-  }
-
-  Map<int, String> bytesToUnicode() {
+  Map<int, String> _bytesToUnicode() {
     final List<int> bs = [];
     for (int i = '!'.codeUnitAt(0); i <= '~'.codeUnitAt(0); i++) {
       bs.add(i);
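`_bytesToUnicode` (a rename here) is the GPT-2/CLIP trick of giving all 256 byte values printable stand-in characters, so BPE can operate on arbitrary UTF-8 as if it were plain text. Since the hunk only shows the first loop, here is a hedged re-sketch of the whole mapping, written from the standard algorithm rather than copied from this file:

Map<int, String> bytesToUnicode() {
  // Printable bytes keep their own code point...
  final bs = <int>[
    for (var i = '!'.codeUnitAt(0); i <= '~'.codeUnitAt(0); i++) i,
    for (var i = '¡'.codeUnitAt(0); i <= '¬'.codeUnitAt(0); i++) i,
    for (var i = '®'.codeUnitAt(0); i <= 'ÿ'.codeUnitAt(0); i++) i,
  ];
  final cs = [...bs];
  // ...while control and whitespace bytes get substitutes from 256 upward.
  var n = 0;
  for (var b = 0; b < 256; b++) {
    if (!bs.contains(b)) {
      bs.add(b);
      cs.add(256 + n);
      n++;
    }
  }
  return {
    for (var i = 0; i < bs.length; i++) bs[i]: String.fromCharCode(cs[i]),
  };
}

void main() {
  final enc = bytesToUnicode();
  print(enc[65]); // 'A' maps to itself
  print(enc[0]);  // a control byte maps to a printable substitute, 'Ā'
}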
@ -171,7 +185,7 @@ class OnnxTextTokenizer {
     return Map.fromIterables(bs, ds);
   }
 
-  Set<Tuple2<String, String>> getPairs(List<String> word) {
+  Set<Tuple2<String, String>> _getPairs(List<String> word) {
     final Set<Tuple2<String, String>> pairs = {};
     String prevChar = word[0];
     for (var i = 1; i < word.length; i++) {
@ -181,13 +195,13 @@ class OnnxTextTokenizer {
     return pairs;
   }
 
-  String basicClean(String text) {
+  String _basicClean(String text) {
     final unescape = HtmlUnescape();
     text = unescape.convert(unescape.convert(text));
     return text.trim();
   }
 
-  String whitespaceClean(String text) {
+  String _whitespaceClean(String text) {
     text = text.replaceAll(RegExp(r'\s+'), ' ');
     return text.trim();
   }
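A useful invariant to check across the two files: with the defaults `nText = 76` and `pad = true`, `tokenize` always returns `nText + 1 = 77` ids (start marker, at most 75 BPE ids, end marker, zero padding), which is exactly the `[1, 77]` input tensor `ClipTextEncoder.infer` builds. A tiny check with made-up BPE ids (49406/49407 are the standard CLIP vocab marker ids):

import "dart:math";

void main() {
  const nText = 76;
  const sot = 49406, eot = 49407; // standard CLIP <|startoftext|>/<|endoftext|> ids
  final ids = [320, 1929, 525, 320, 4579]; // made-up BPE ids for illustration
  var tokens = [sot] + ids.sublist(0, min(nText - 1, ids.length)) + [eot];
  tokens = tokens + List.filled(nText + 1 - tokens.length, 0);
  print(tokens.length); // 77, matching the [1, 77] tensor in ClipTextEncoder
}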
semantic_search_service.dart

@ -297,7 +297,7 @@ class SemanticSearchService {
     try {
       final int clipAddress = ClipTextEncoder.instance.sessionAddress;
       final textEmbedding = await _computer.compute(
-        ClipTextEncoder.instance.infer,
+        ClipTextEncoder.infer,
         param: {
           "text": query,
           "address": clipAddress,
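This call-site change is the point of the commit: `ClipTextEncoder.instance.infer` is an instance tear-off and cannot be sent to a worker isolate, while the static `ClipTextEncoder.infer` can. A hedged sketch of the call site pulled into a helper (`embedQuery` is hypothetical, the encoder import path is inferred from the tokenizer import above, and the `Computer` instance is assumed to be already turned on):

import "package:computer/computer.dart";
import "package:photos/services/machine_learning/semantic_search/clip/clip_text_encoder.dart";

Future<List<double>> embedQuery(Computer computer, String query) async {
  // The native session handle travels as a plain int, which is sendable.
  final int clipAddress = ClipTextEncoder.instance.sessionAddress;
  final result = await computer.compute(
    ClipTextEncoder.infer,
    param: {"text": query, "address": clipAddress},
  );
  return result as List<double>;
}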