// ente/mobile/lib/utils/ml_util.dart

import "dart:io" show File;
import "dart:math" as math show sqrt, min, max;
import "package:flutter/services.dart" show PlatformException;
import "package:logging/logging.dart";
import "package:photos/core/configuration.dart";
import "package:photos/db/files_db.dart";
import "package:photos/db/ml/clip_db.dart";
import "package:photos/db/ml/db.dart";
import "package:photos/extensions/list.dart";
import "package:photos/models/file/extensions/file_props.dart";
import "package:photos/models/file/file.dart";
import "package:photos/models/file/file_type.dart";
import "package:photos/models/ml/clip.dart";
import "package:photos/models/ml/face/dimension.dart";
import "package:photos/models/ml/face/face.dart";
import "package:photos/models/ml/ml_versions.dart";
import "package:photos/service_locator.dart";
import "package:photos/services/filedata/filedata_service.dart";
import "package:photos/services/filedata/model/file_data.dart";
import "package:photos/services/machine_learning/face_ml/face_recognition_service.dart";
import "package:photos/services/machine_learning/ml_exceptions.dart";
import "package:photos/services/machine_learning/ml_result.dart";
import "package:photos/services/machine_learning/semantic_search/semantic_search_service.dart";
import "package:photos/services/search_service.dart";
import "package:photos/utils/file_util.dart";
import "package:photos/utils/image_ml_util.dart";
import "package:photos/utils/network_util.dart";
import "package:photos/utils/thumbnail_util.dart";

final _logger = Logger("MlUtil");

enum FileDataForML { thumbnailData, fileData }
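
/// A snapshot of ML indexing progress: how many files are fully indexed, how
/// many are still pending, and whether high-bandwidth (WiFi) transfer is
/// currently allowed.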
class IndexStatus {
final int indexedItems, pendingItems;
final bool? hasWifiEnabled;
IndexStatus(this.indexedItems, this.pendingItems, [this.hasWifiEnabled]);
}
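
/// A unit of ML work for a single [EnteFile]: which pipelines (faces, clip)
/// still need to run, plus any embedding data already available on remote.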
class FileMLInstruction {
final EnteFile file;
bool shouldRunFaces;
bool shouldRunClip;
FileDataEntity? existingRemoteFileML;
FileMLInstruction({
required this.file,
required this.shouldRunFaces,
required this.shouldRunClip,
});
  /// Returns true if the file should be indexed for either faces or clip.
bool get pendingML => shouldRunFaces || shouldRunClip;
}
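
/// Returns the [IndexStatus] shown to the user. A file only counts as indexed
/// once both the faces and clip pipelines have processed it (hence the `min`
/// of the two counts), and the displayed numbers are clamped so that indexed
/// never exceeds indexable and pending never goes negative.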
Future<IndexStatus> getIndexStatus() async {
try {
final int indexableFiles = (await getIndexableFileIDs()).length;
final int facesIndexedFiles =
await MLDataDB.instance.getFaceIndexedFileCount();
final int clipIndexedFiles =
await MLDataDB.instance.getClipIndexedFileCount();
final int indexedFiles = math.min(facesIndexedFiles, clipIndexedFiles);
final showIndexedFiles = math.min(indexedFiles, indexableFiles);
final showPendingFiles = math.max(indexableFiles - indexedFiles, 0);
final hasWifiEnabled = await canUseHighBandwidth();
_logger.info(
"Shown IndexStatus: indexedFiles: $showIndexedFiles, pendingFiles: $showPendingFiles, hasWifiEnabled: $hasWifiEnabled. Real values: indexedFiles: $indexedFiles (faces: $facesIndexedFiles, clip: $clipIndexedFiles), indexableFiles: $indexableFiles",
);
return IndexStatus(showIndexedFiles, showPendingFiles, hasWifiEnabled);
} catch (e, s) {
_logger.severe('Error getting ML status', e, s);
rethrow;
}
}

/// Returns the list of file instructions for files that should be indexed for
/// ML, ordered so that files with a localID come first, then files without
/// one, then hidden files.
Future<List<FileMLInstruction>> getFilesForMlIndexing() async {
_logger.info('getFilesForMlIndexing called');
final time = DateTime.now();
// Get indexed fileIDs for each ML service
final Map<int, int> faceIndexedFileIDs =
await MLDataDB.instance.faceIndexedFileIds();
final Map<int, int> clipIndexedFileIDs =
await MLDataDB.instance.clipIndexedFileWithVersion();
  final Set<int> queuedFileIDs = {};
// Get all regular files and all hidden files
final enteFiles = await SearchService.instance.getAllFiles();
final hiddenFiles = await SearchService.instance.getHiddenFiles();
// Sort out what should be indexed and in what order
final List<FileMLInstruction> filesWithLocalID = [];
final List<FileMLInstruction> filesWithoutLocalID = [];
final List<FileMLInstruction> hiddenFilesToIndex = [];
for (final EnteFile enteFile in enteFiles) {
if (_skipAnalysisEnteFile(enteFile)) {
continue;
}
    if (queuedFileIDs.contains(enteFile.uploadedFileID)) {
continue;
}
    queuedFileIDs.add(enteFile.uploadedFileID!);
final shouldRunFaces =
_shouldRunIndexing(enteFile, faceIndexedFileIDs, faceMlVersion);
final shouldRunClip =
_shouldRunIndexing(enteFile, clipIndexedFileIDs, clipMlVersion);
if (!shouldRunFaces && !shouldRunClip) {
continue;
}
final instruction = FileMLInstruction(
file: enteFile,
shouldRunFaces: shouldRunFaces,
shouldRunClip: shouldRunClip,
);
if ((enteFile.localID ?? '').isEmpty) {
filesWithoutLocalID.add(instruction);
} else {
filesWithLocalID.add(instruction);
}
}
for (final EnteFile enteFile in hiddenFiles) {
if (_skipAnalysisEnteFile(enteFile)) {
continue;
}
    if (queuedFileIDs.contains(enteFile.uploadedFileID)) {
continue;
}
    queuedFileIDs.add(enteFile.uploadedFileID!);
final shouldRunFaces =
_shouldRunIndexing(enteFile, faceIndexedFileIDs, faceMlVersion);
final shouldRunClip =
_shouldRunIndexing(enteFile, clipIndexedFileIDs, clipMlVersion);
if (!shouldRunFaces && !shouldRunClip) {
continue;
}
final instruction = FileMLInstruction(
file: enteFile,
shouldRunFaces: shouldRunFaces,
shouldRunClip: shouldRunClip,
);
hiddenFilesToIndex.add(instruction);
}
  final sortedByLocalID = <FileMLInstruction>[
...filesWithLocalID,
...filesWithoutLocalID,
...hiddenFilesToIndex,
];
_logger.info(
"Getting list of files to index for ML took ${DateTime.now().difference(time).inMilliseconds} ms",
);
  return sortedByLocalID;
}
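
/// Streams batches of at most [yieldSize] instructions that still require
/// on-device indexing. Candidates are processed in chunks of
/// [embeddingFetchLimit]: for each chunk, any usable remote embeddings are
/// fetched and stored locally, and only files whose faces or clip embeddings
/// are still missing (or outdated) are yielded. When
/// `localSettings.remoteFetchEnabled` is false, the remote fetch is skipped
/// and all candidates are yielded for local indexing.
///
/// A minimal consumption sketch (the batch size of 25 is illustrative):
///
/// ```dart
/// await for (final batch in fetchEmbeddingsAndInstructions(25)) {
///   for (final instruction in batch) {
///     // Run the on-device pipelines indicated by
///     // instruction.shouldRunFaces and instruction.shouldRunClip.
///   }
/// }
/// ```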
Stream<List<FileMLInstruction>> fetchEmbeddingsAndInstructions(
int yieldSize,
) async* {
final List<FileMLInstruction> filesToIndex = await getFilesForMlIndexing();
final List<List<FileMLInstruction>> chunks =
filesToIndex.chunks(embeddingFetchLimit);
List<FileMLInstruction> batchToYield = [];
for (final chunk in chunks) {
if (!localSettings.remoteFetchEnabled) {
_logger.warning("remoteFetchEnabled is false, skiping embedding fetch");
final batches = chunk.chunks(yieldSize);
for (final batch in batches) {
yield batch;
}
continue;
}
final Set<int> ids = {};
final Map<int, FileMLInstruction> pendingIndex = {};
for (final instruction in chunk) {
ids.add(instruction.file.uploadedFileID!);
pendingIndex[instruction.file.uploadedFileID!] = instruction;
}
_logger.info("fetching embeddings for ${ids.length} files");
final res = await FileDataService.instance.getFilesData(ids);
_logger.info("embeddingResponse ${res.debugLog()}");
final List<Face> faces = [];
final List<ClipEmbedding> clipEmbeddings = [];
for (FileDataEntity fileMl in res.data.values) {
final existingInstruction = pendingIndex[fileMl.fileID]!;
final facesFromRemoteEmbedding = _getFacesFromRemoteEmbedding(fileMl);
      // Note: always null-check here; an empty list means no face was found.
if (facesFromRemoteEmbedding != null) {
faces.addAll(facesFromRemoteEmbedding);
existingInstruction.shouldRunFaces = false;
}
if (fileMl.clipEmbedding != null &&
fileMl.clipEmbedding!.version >= clipMlVersion) {
clipEmbeddings.add(
ClipEmbedding(
fileID: fileMl.fileID,
embedding: fileMl.clipEmbedding!.embedding,
version: fileMl.clipEmbedding!.version,
),
);
existingInstruction.shouldRunClip = false;
}
if (!existingInstruction.pendingML) {
pendingIndex.remove(fileMl.fileID);
} else {
existingInstruction.existingRemoteFileML = fileMl;
pendingIndex[fileMl.fileID] = existingInstruction;
}
}
    for (final instruction in pendingIndex.values) {
if (instruction.pendingML) {
batchToYield.add(instruction);
if (batchToYield.length == yieldSize) {
_logger.info("queueing indexing for $yieldSize");
yield batchToYield;
batchToYield = [];
}
}
}
await MLDataDB.instance.bulkInsertFaces(faces);
await MLDataDB.instance.putMany(clipEmbeddings);
}
// Yield any remaining instructions
if (batchToYield.isNotEmpty) {
_logger.info("queueing indexing for ${batchToYield.length}");
yield batchToYield;
}
}

/// Returns the faces stored in the given remote [fileMl], or null if the
/// remote face embedding should be discarded (missing, older than
/// [faceMlVersion], or with degenerate landmarks); see
/// [_shouldDiscardRemoteEmbedding]. A remote result with zero faces is
/// recorded as a single [Face.empty] entry.
List<Face>? _getFacesFromRemoteEmbedding(FileDataEntity fileMl) {
final RemoteFaceEmbedding? remoteFaceEmbedding = fileMl.faceEmbedding;
if (_shouldDiscardRemoteEmbedding(fileMl)) {
return null;
}
final List<Face> faces = [];
if (remoteFaceEmbedding!.faces.isEmpty) {
faces.add(
Face.empty(fileMl.fileID),
);
} else {
for (final f in remoteFaceEmbedding.faces) {
f.fileInfo = FileInfo(
imageHeight: remoteFaceEmbedding.height,
imageWidth: remoteFaceEmbedding.width,
);
faces.add(f);
}
}
return faces;
}
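
/// Returns true if the remote face embedding of [fileML] cannot be reused:
/// it is missing, its version is older than [faceMlVersion], or every
/// detected landmark has x == y, a degenerate pattern treated as an invalid
/// embedding.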
bool _shouldDiscardRemoteEmbedding(FileDataEntity fileML) {
final fileID = fileML.fileID;
final RemoteFaceEmbedding? faceEmbedding = fileML.faceEmbedding;
if (faceEmbedding == null || faceEmbedding.version < faceMlVersion) {
_logger.info("Discarding remote embedding for fileID $fileID "
"because version is ${faceEmbedding?.version} and we need $faceMlVersion");
return true;
}
  // Check for the degenerate case where every landmark of every face has
  // x == y, which indicates an invalid remote embedding.
  bool allLandmarksEqual = true;
if (faceEmbedding.faces.isEmpty) {
allLandmarksEqual = false;
}
for (final face in faceEmbedding.faces) {
if (face.detection.landmarks.isEmpty) {
allLandmarksEqual = false;
break;
}
if (face.detection.landmarks.any((landmark) => landmark.x != landmark.y)) {
allLandmarksEqual = false;
break;
}
}
if (allLandmarksEqual) {
_logger.info("Discarding remote embedding for fileID $fileID "
"because landmarks are equal");
_logger.info(
faceEmbedding.faces
.map((e) => e.detection.landmarks.toString())
.toList()
.toString(),
);
return true;
}
return false;
}
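
/// Returns the uploaded file IDs owned by the current user; these are the
/// files considered for ML indexing.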
Future<Set<int>> getIndexableFileIDs() async {
final fileIDs = await FilesDB.instance
.getOwnedFileIDs(Configuration.instance.getUserID()!);
return fileIDs.toSet();
}
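
/// Returns a local image path usable for ML: the thumbnail for videos, and
/// the original file otherwise. Throws if the file exceeds
/// [maxFileDownloadSize] or if no file data could be retrieved.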
Future<String> getImagePathForML(EnteFile enteFile) async {
String? imagePath;
final stopwatch = Stopwatch()..start();
File? file;
final bool isVideo = enteFile.fileType == FileType.video;
if (isVideo) {
try {
file = await getThumbnailForUploadedFile(enteFile);
} on PlatformException catch (e, s) {
_logger.severe(
"Could not get thumbnail for $enteFile due to PlatformException",
e,
s,
);
throw ThumbnailRetrievalException(e.toString(), s);
}
} else {
// Don't process the file if it's too large (more than 100MB)
if (enteFile.fileSize != null && enteFile.fileSize! > maxFileDownloadSize) {
throw Exception(
"FileSizeTooLargeForMobileIndexing: size is ${enteFile.fileSize}",
);
}
try {
file = await getFile(enteFile, isOrigin: true);
} catch (e, s) {
_logger.severe(
"Could not get file for $enteFile",
e,
s,
);
}
}
imagePath = file?.path;
stopwatch.stop();
_logger.info(
"Getting file data for fileID ${enteFile.uploadedFileID} took ${stopwatch.elapsedMilliseconds} ms",
);
if (imagePath == null) {
_logger.severe(
"Failed to get any data for enteFile with uploadedFileID ${enteFile.uploadedFileID} and format ${enteFile.displayName.split('.').last} and size ${enteFile.fileSize} since its file path is null (isVideo: $isVideo)",
);
throw CouldNotRetrieveAnyFileData();
}
return imagePath;
}
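
/// Returns true if [enteFile] should be excluded from ML analysis.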
bool _skipAnalysisEnteFile(EnteFile enteFile) {
// Skip if the file is not uploaded or not owned by the user
if (!enteFile.isUploaded || enteFile.isOwner == false) {
return true;
}
  // It's unclear how motionPhotos and livePhotos work, so skip those for now
if (enteFile.fileType == FileType.other) {
return true;
}
return false;
}
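
/// Returns true if the file has not been indexed yet, or was indexed with a
/// model version older than [newestVersion].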
bool _shouldRunIndexing(
EnteFile enteFile,
Map<int, int> indexedFileIds,
int newestVersion,
) {
final id = enteFile.uploadedFileID!;
return !indexedFileIds.containsKey(id) || indexedFileIds[id]! < newestVersion;
}
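
/// Normalizes [embedding] in place to unit L2 norm: every component is
/// divided by the square root of the sum of squared components.
///
/// For example:
///
/// ```dart
/// final v = [3.0, 4.0];
/// normalizeEmbedding(v); // v is now [0.6, 0.8], since its L2 norm was 5
/// ```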
void normalizeEmbedding(List<double> embedding) {
double normalization = 0;
for (int i = 0; i < embedding.length; i++) {
normalization += embedding[i] * embedding[i];
}
final double sqrtNormalization = math.sqrt(normalization);
for (int i = 0; i < embedding.length; i++) {
embedding[i] = embedding[i] / sqrtNormalization;
}
}
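
/// Decodes the image at `args["filePath"]` once, then runs the face and clip
/// pipelines concurrently (each gated by its `runFaces` / `runClip` flag) and
/// returns the combined [MLResult]. The `*Address` entries are the native
/// model handles that are passed through to the pipelines.
///
/// A sketch of the expected [args] map (the right-hand values are
/// placeholders):
///
/// ```dart
/// final result = await analyzeImageStatic({
///   "enteFileID": enteFile.uploadedFileID!,
///   "filePath": imagePath,
///   "runFaces": true,
///   "runClip": true,
///   "faceDetectionAddress": faceDetectionAddress,
///   "faceEmbeddingAddress": faceEmbeddingAddress,
///   "clipImageAddress": clipImageAddress,
/// });
/// ```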
Future<MLResult> analyzeImageStatic(Map args) async {
try {
final int enteFileID = args["enteFileID"] as int;
final String imagePath = args["filePath"] as String;
final bool runFaces = args["runFaces"] as bool;
final bool runClip = args["runClip"] as bool;
final int faceDetectionAddress = args["faceDetectionAddress"] as int;
final int faceEmbeddingAddress = args["faceEmbeddingAddress"] as int;
final int clipImageAddress = args["clipImageAddress"] as int;
_logger.info(
"Start analyzeImageStatic for fileID $enteFileID (runFaces: $runFaces, runClip: $runClip)",
);
final startTime = DateTime.now();
// Decode the image once to use for both face detection and alignment
final (image, rawRgbaBytes) = await decodeImageFromPath(imagePath);
final decodedImageSize =
Dimensions(height: image.height, width: image.width);
final result = MLResult.fromEnteFileID(enteFileID);
result.decodedImageSize = decodedImageSize;
final decodeTime = DateTime.now();
final decodeMs = decodeTime.difference(startTime).inMilliseconds;
String faceMsString = "", clipMsString = "";
final pipelines = await Future.wait([
runFaces
? FaceRecognitionService.runFacesPipeline(
enteFileID,
image,
rawRgbaBytes,
faceDetectionAddress,
faceEmbeddingAddress,
).then((result) {
faceMsString =
", faces: ${DateTime.now().difference(decodeTime).inMilliseconds} ms";
return result;
})
: Future.value(null),
runClip
? SemanticSearchService.runClipImage(
enteFileID,
image,
rawRgbaBytes,
clipImageAddress,
).then((result) {
clipMsString =
", clip: ${DateTime.now().difference(decodeTime).inMilliseconds} ms";
return result;
})
: Future.value(null),
]);
if (pipelines[0] != null) result.faces = pipelines[0] as List<FaceResult>;
if (pipelines[1] != null) result.clip = pipelines[1] as ClipResult;
final totalMs = DateTime.now().difference(startTime).inMilliseconds;
_logger.info(
'Finished analyzeImageStatic for fileID $enteFileID, in $totalMs ms (decode: $decodeMs ms$faceMsString$clipMsString)',
);
return result;
} catch (e, s) {
_logger.severe("Could not analyze image", e, s);
rethrow;
}
}