[mob] Cleaner handling of decoded image dimensions

laurenspriem 2024-04-08 15:24:14 +05:30
parent eeedf8b3c2
commit 4cb15268e9
7 changed files with 97 additions and 76 deletions

View File

@@ -0,0 +1,25 @@
+class Dimensions {
+  final int width;
+  final int height;
+
+  const Dimensions({required this.width, required this.height});
+
+  @override
+  String toString() {
+    return 'Dimensions(width: $width, height: $height)';
+  }
+
+  Map<String, int> toJson() {
+    return {
+      'width': width,
+      'height': height,
+    };
+  }
+
+  factory Dimensions.fromJson(Map<String, dynamic> json) {
+    return Dimensions(
+      width: json['width'] as int,
+      height: json['height'] as int,
+    );
+  }
+}
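A quick sketch of how the new value class round-trips through JSON (illustration only, not part of the commit; it assumes nothing beyond the `Dimensions` class above):

```dart
void main() {
  // Serialize a decoded image's size to a plain map.
  const size = Dimensions(width: 4032, height: 3024);
  final Map<String, int> json = size.toJson(); // {'width': 4032, 'height': 3024}

  // Map<String, int> is a subtype of Map<String, dynamic>, so the map can be
  // fed straight back into the factory.
  final restored = Dimensions.fromJson(json);
  assert(restored.width == size.width && restored.height == size.height);
  print(restored); // Dimensions(width: 4032, height: 3024)
}
```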

View File

@@ -1,5 +1,6 @@
 import 'dart:math' show sqrt, pow;
 import 'dart:ui' show Size;
+import "package:photos/face/model/dimension.dart";

 abstract class Detection {
   final double score;
@@ -179,8 +180,8 @@ class FaceDetectionRelative extends Detection {
   }

   void correctForMaintainedAspectRatio(
-    Size originalSize,
-    Size newSize,
+    Dimensions originalSize,
+    Dimensions newSize,
   ) {
     // Return if both are the same size, meaning no scaling was done on both width and height
     if (originalSize == newSize) {
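The body of `correctForMaintainedAspectRatio` is not part of this hunk. As a rough illustration of the idea (an assumption about the method's purpose, not its actual implementation): when an image is scaled into the model input with its aspect ratio preserved, only part of the input canvas carries image content, so a coordinate normalized over the full canvas has to be re-normalized over the used region:

```dart
// Illustration only, not the body of correctForMaintainedAspectRatio.
// `relative` is in [0, 1] over `fullPixels` of the model input, but only
// `usedPixels` of that axis contain image content after aspect-preserving
// scaling, so the coordinate is stretched accordingly.
double stretchRelativeCoordinate(double relative, int usedPixels, int fullPixels) {
  return relative * fullPixels / usedPixels;
}
```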

View File

@@ -9,6 +9,7 @@ import "package:computer/computer.dart";
 import 'package:flutter/material.dart';
 import 'package:logging/logging.dart';
 import 'package:onnxruntime/onnxruntime.dart';
+import "package:photos/face/model/dimension.dart";
 import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart';
 import 'package:photos/services/machine_learning/face_ml/face_detection/naive_non_max_suppression.dart';
 import 'package:photos/services/machine_learning/face_ml/face_detection/yolov5face/yolo_face_detection_exceptions.dart';
@@ -143,7 +144,7 @@ class YoloOnnxFaceDetection {
       case FaceDetectionOperation.yoloInferenceAndPostProcessing:
         final inputImageList = args['inputImageList'] as Float32List;
         final inputShape = args['inputShape'] as List<int>;
-        final newSize = args['newSize'] as Size;
+        final newSize = args['newSize'] as Dimensions;
         final sessionAddress = args['sessionAddress'] as int;
         final timeSentToIsolate = args['timeNow'] as DateTime;
         final delaySentToIsolate =
@@ -249,7 +250,7 @@ class YoloOnnxFaceDetection {
   }

   /// Detects faces in the given image data.
-  Future<(List<FaceDetectionRelative>, Size)> predict(
+  Future<(List<FaceDetectionRelative>, Dimensions)> predict(
     Uint8List imageData,
   ) async {
     assert(isInitialized);
@@ -314,7 +315,7 @@ class YoloOnnxFaceDetection {
   }

   /// Detects faces in the given image data.
-  static Future<(List<FaceDetectionRelative>, Size)> predictSync(
+  static Future<(List<FaceDetectionRelative>, Dimensions)> predictSync(
     ui.Image image,
     ByteData imageByteData,
     int sessionAddress,
@@ -384,7 +385,7 @@ class YoloOnnxFaceDetection {
   }

   /// Detects faces in the given image data.
-  Future<(List<FaceDetectionRelative>, Size)> predictInIsolate(
+  Future<(List<FaceDetectionRelative>, Dimensions)> predictInIsolate(
     Uint8List imageData,
   ) async {
     await ensureSpawnedIsolate();
@@ -446,7 +447,7 @@ class YoloOnnxFaceDetection {
     return (relativeDetections, originalSize);
   }

-  Future<(List<FaceDetectionRelative>, Size)> predictInComputer(
+  Future<(List<FaceDetectionRelative>, Dimensions)> predictInComputer(
     String imagePath,
   ) async {
     assert(isInitialized);
@@ -524,7 +525,7 @@ class YoloOnnxFaceDetection {
     final stopwatchDecoding = Stopwatch()..start();
     final List<Float32List> inputImageDataLists = [];
-    final List<(Size, Size)> originalAndNewSizeList = [];
+    final List<(Dimensions, Dimensions)> originalAndNewSizeList = [];
     int concatenatedImageInputsLength = 0;
     for (final imageData in imageDataList) {
       final (inputImageList, originalSize, newSize) =
@@ -624,9 +625,9 @@ class YoloOnnxFaceDetection {
       // Account for the fact that the aspect ratio was maintained
       for (final faceDetection in relativeDetections) {
         faceDetection.correctForMaintainedAspectRatio(
-          Size(
-            kInputWidth.toDouble(),
-            kInputHeight.toDouble(),
+          const Dimensions(
+            width: kInputWidth,
+            height: kInputHeight,
           ),
           originalAndNewSizeList[imageOutputToUse].$2,
         );
@@ -653,7 +654,7 @@ class YoloOnnxFaceDetection {
   static List<FaceDetectionRelative> _yoloPostProcessOutputs(
     List<OrtValue?>? outputs,
-    Size newSize,
+    Dimensions newSize,
   ) {
     // // Get output tensors
     final nestedResults =
@@ -684,9 +685,9 @@ class YoloOnnxFaceDetection {
     // Account for the fact that the aspect ratio was maintained
     for (final faceDetection in relativeDetections) {
       faceDetection.correctForMaintainedAspectRatio(
-        Size(
-          kInputWidth.toDouble(),
-          kInputHeight.toDouble(),
+        const Dimensions(
+          width: kInputWidth,
+          height: kInputHeight,
         ),
         newSize,
       );
@@ -735,7 +736,7 @@ class YoloOnnxFaceDetection {
   ) async {
     final inputImageList = args['inputImageList'] as Float32List;
     final inputShape = args['inputShape'] as List<int>;
-    final newSize = args['newSize'] as Size;
+    final newSize = args['newSize'] as Dimensions;
     final sessionAddress = args['sessionAddress'] as int;
     final timeSentToIsolate = args['timeNow'] as DateTime;
     final delaySentToIsolate =
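One pattern worth noting in this file: the isolate entry points receive their inputs through an untyped `Map<String, dynamic>`, which is why each field is cast on arrival, now to `Dimensions` instead of `Size`. A self-contained sketch with placeholder values (assumes only the `Dimensions` class above):

```dart
void main() {
  // Hypothetical args map; the values are placeholders, not real pipeline inputs.
  final args = <String, dynamic>{
    'inputShape': <int>[1, 3, 640, 640],
    'newSize': const Dimensions(width: 640, height: 480),
    'timeNow': DateTime.now(),
  };

  // The receiving side casts each field back out of the untyped map.
  final inputShape = args['inputShape'] as List<int>;
  final newSize = args['newSize'] as Dimensions;
  print('$inputShape, $newSize');
}
```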

View File

@@ -1,7 +1,8 @@
 import "dart:convert" show jsonEncode, jsonDecode;

-import "package:flutter/material.dart" show Size, debugPrint, immutable;
+import "package:flutter/material.dart" show debugPrint, immutable;
 import "package:logging/logging.dart";
+import "package:photos/face/model/dimension.dart";
 import "package:photos/models/file/file.dart";
 import 'package:photos/models/ml/ml_typedefs.dart';
 import "package:photos/models/ml/ml_versions.dart";
@@ -284,8 +285,7 @@ class FaceMlResult {
   final List<FaceResult> faces;

-  final Size? faceDetectionImageSize;
-  final Size? faceAlignmentImageSize;
+  final Dimensions decodedImageSize;

   final int mlVersion;
   final bool errorOccured;
@@ -319,8 +319,7 @@ class FaceMlResult {
     required this.mlVersion,
     required this.errorOccured,
     required this.onlyThumbnailUsed,
-    required this.faceDetectionImageSize,
-    this.faceAlignmentImageSize,
+    required this.decodedImageSize,
   });

   Map<String, dynamic> _toJson() => {
@@ -329,16 +328,10 @@ class FaceMlResult {
         'mlVersion': mlVersion,
         'errorOccured': errorOccured,
         'onlyThumbnailUsed': onlyThumbnailUsed,
-        if (faceDetectionImageSize != null)
-          'faceDetectionImageSize': {
-            'width': faceDetectionImageSize!.width,
-            'height': faceDetectionImageSize!.height,
-          },
-        if (faceAlignmentImageSize != null)
-          'faceAlignmentImageSize': {
-            'width': faceAlignmentImageSize!.width,
-            'height': faceAlignmentImageSize!.height,
-          },
+        'decodedImageSize': {
+          'width': decodedImageSize.width,
+          'height': decodedImageSize.height,
+        },
       };

   String toJsonString() => jsonEncode(_toJson());
@@ -352,18 +345,19 @@ class FaceMlResult {
       mlVersion: json['mlVersion'],
       errorOccured: json['errorOccured'] ?? false,
       onlyThumbnailUsed: json['onlyThumbnailUsed'] ?? false,
-      faceDetectionImageSize: json['faceDetectionImageSize'] == null
-          ? null
-          : Size(
-              json['faceDetectionImageSize']['width'],
-              json['faceDetectionImageSize']['height'],
-            ),
-      faceAlignmentImageSize: json['faceAlignmentImageSize'] == null
-          ? null
-          : Size(
-              json['faceAlignmentImageSize']['width'],
-              json['faceAlignmentImageSize']['height'],
-            ),
+      decodedImageSize: json['decodedImageSize'] != null
+          ? Dimensions(
+              width: json['decodedImageSize']['width'],
+              height: json['decodedImageSize']['height'],
+            )
+          : json['faceDetectionImageSize'] == null
+              ? const Dimensions(width: -1, height: -1)
+              : Dimensions(
+                  width: (json['faceDetectionImageSize']['width'] as double)
+                      .truncate(),
+                  height: (json['faceDetectionImageSize']['height'] as double)
+                      .truncate(),
+                ),
     );
   }
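The fallback above is the migration path: legacy results serialized a `ui.Size`, whose width and height decode as doubles and must be truncated to ints, while results carrying neither key get the `(-1, -1)` sentinel. A minimal sketch of the legacy branch (not part of the commit; reuses the `Dimensions` class from above):

```dart
import 'dart:convert' show jsonDecode;

void main() {
  // Hypothetical legacy payload: ui.Size stored width/height as doubles.
  final json = jsonDecode(
    '{"faceDetectionImageSize": {"width": 4032.0, "height": 3024.0}}',
  ) as Map<String, dynamic>;

  // jsonDecode yields doubles here, hence the cast-then-truncate above.
  final migrated = Dimensions(
    width: (json['faceDetectionImageSize']['width'] as double).truncate(),
    height: (json['faceDetectionImageSize']['height'] as double).truncate(),
  );
  print(migrated); // Dimensions(width: 4032, height: 3024)
}
```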
@@ -400,8 +394,7 @@ class FaceMlResultBuilder {
   List<FaceResultBuilder> faces = <FaceResultBuilder>[];

-  Size? faceDetectionImageSize;
-  Size? faceAlignmentImageSize;
+  Dimensions decodedImageSize;

   int mlVersion;
   bool errorOccured;
@@ -412,6 +405,7 @@ class FaceMlResultBuilder {
     this.mlVersion = faceMlVersion,
     this.errorOccured = false,
     this.onlyThumbnailUsed = false,
+    this.decodedImageSize = const Dimensions(width: -1, height: -1),
   });

   FaceMlResultBuilder.fromEnteFile(
@@ -419,6 +413,7 @@ class FaceMlResultBuilder {
     this.mlVersion = faceMlVersion,
     this.errorOccured = false,
     this.onlyThumbnailUsed = false,
+    this.decodedImageSize = const Dimensions(width: -1, height: -1),
   }) : fileId = file.uploadedFileID ?? -1;

   FaceMlResultBuilder.fromEnteFileID(
@@ -426,13 +421,14 @@ class FaceMlResultBuilder {
     this.mlVersion = faceMlVersion,
     this.errorOccured = false,
     this.onlyThumbnailUsed = false,
+    this.decodedImageSize = const Dimensions(width: -1, height: -1),
   }) : fileId = fileID;

   void addNewlyDetectedFaces(
     List<FaceDetectionRelative> faceDetections,
-    Size originalSize,
+    Dimensions originalSize,
   ) {
-    faceDetectionImageSize = originalSize;
+    decodedImageSize = originalSize;
     for (var i = 0; i < faceDetections.length; i++) {
       faces.add(
         FaceResultBuilder.fromFaceDetection(
@@ -446,7 +442,6 @@ class FaceMlResultBuilder {
   void addAlignmentResults(
     List<AlignmentResult> alignmentResults,
     List<double> blurValues,
-    Size imageSizeUsedForAlignment,
   ) {
     if (alignmentResults.length != faces.length) {
       throw Exception(
@@ -458,7 +453,6 @@ class FaceMlResultBuilder {
       faces[i].alignment = alignmentResults[i];
       faces[i].blurValue = blurValues[i];
     }
-    faceAlignmentImageSize = imageSizeUsedForAlignment;
   }

   void addEmbeddingsToExistingFaces(
@@ -485,8 +479,7 @@ class FaceMlResultBuilder {
       mlVersion: mlVersion,
       errorOccured: errorOccured,
       onlyThumbnailUsed: onlyThumbnailUsed,
-      faceDetectionImageSize: faceDetectionImageSize,
-      faceAlignmentImageSize: faceAlignmentImageSize,
+      decodedImageSize: decodedImageSize,
     );
   }

View File

@@ -661,13 +661,13 @@ class FaceMlService {
           ),
         );
       } else {
-        if (result.faceDetectionImageSize == null ||
-            result.faceAlignmentImageSize == null) {
-          _logger.severe(
-              "faceDetectionImageSize or faceDetectionImageSize is null for image with "
-              "ID: ${enteFile.uploadedFileID}");
+        if (result.decodedImageSize.width == -1 ||
+            result.decodedImageSize.height == -1) {
+          _logger
+              .severe("decodedImageSize is not stored correctly for image with "
+                  "ID: ${enteFile.uploadedFileID}");
           _logger.info(
-            "Using aligned image size for image with ID: ${enteFile.uploadedFileID}. This size is ${result.faceAlignmentImageSize!.width}x${result.faceAlignmentImageSize!.height} compared to size of ${enteFile.width}x${enteFile.height} in the metadata",
+            "Using aligned image size for image with ID: ${enteFile.uploadedFileID}. This size is ${result.decodedImageSize.width}x${result.decodedImageSize.height} compared to size of ${enteFile.width}x${enteFile.height} in the metadata",
           );
         }
         for (int i = 0; i < result.faces.length; ++i) {
@@ -697,8 +697,8 @@ class FaceMlService {
               detection,
               faceRes.blurValue,
               fileInfo: FileInfo(
-                imageHeight: result.faceDetectionImageSize!.height.truncate(),
-                imageWidth: result.faceDetectionImageSize!.width.truncate(),
+                imageHeight: result.decodedImageSize.height,
+                imageWidth: result.decodedImageSize.width,
               ),
             ),
           );
@@ -714,8 +714,8 @@ class FaceMlService {
             result.mlVersion,
             error: result.errorOccured ? true : null,
           ),
-          height: result.faceDetectionImageSize!.height.truncate(),
-          width: result.faceDetectionImageSize!.width.truncate(),
+          height: result.decodedImageSize.height,
+          width: result.decodedImageSize.width,
         ),
       );
       await FaceMLDataDB.instance.bulkInsertFaces(faces);
@@ -1093,7 +1093,7 @@ class FaceMlService {
     FaceMlResultBuilder? resultBuilder,
   }) async {
     try {
-      final (alignedFaces, alignmentResults, _, blurValues, originalImageSize) =
+      final (alignedFaces, alignmentResults, _, blurValues, _) =
           await ImageMlIsolate.instance
               .preprocessMobileFaceNetOnnx(imagePath, faces);
@@ -1101,7 +1101,6 @@ class FaceMlService {
       resultBuilder.addAlignmentResults(
         alignmentResults,
         blurValues,
-        originalImageSize,
       );
     }
@@ -1128,7 +1127,7 @@ class FaceMlService {
   }) async {
     try {
       final stopwatch = Stopwatch()..start();
-      final (alignedFaces, alignmentResults, _, blurValues, originalImageSize) =
+      final (alignedFaces, alignmentResults, _, blurValues, _) =
           await preprocessToMobileFaceNetFloat32List(
             image,
             imageByteData,
@@ -1143,7 +1142,6 @@ class FaceMlService {
       resultBuilder.addAlignmentResults(
         alignmentResults,
         blurValues,
-        originalImageSize,
       );
     }
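Since `decodedImageSize` is now non-nullable, "never stored" is encoded as the `(-1, -1)` sentinel that the guard above checks. The same check as a hypothetical standalone helper (illustration only; the commit inlines it):

```dart
// Hypothetical helper, not in the commit; assumes the Dimensions class above.
bool hasStoredSize(Dimensions size) => size.width != -1 && size.height != -1;
```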

View File

@@ -8,6 +8,7 @@ import "package:flutter/rendering.dart";
 import 'package:flutter_isolate/flutter_isolate.dart';
 import "package:logging/logging.dart";
 import "package:photos/face/model/box.dart";
+import "package:photos/face/model/dimension.dart";
 import 'package:photos/models/ml/ml_typedefs.dart';
 import 'package:photos/services/machine_learning/face_ml/face_alignment/alignment_result.dart';
 import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart';
@@ -343,7 +344,7 @@ class ImageMlIsolate {
   @Deprecated(
     "Old method, not needed since we now run the whole ML pipeline for faces in a single isolate",
   )
-  Future<(Float32List, Size, Size)> preprocessImageYoloOnnx(
+  Future<(Float32List, Dimensions, Dimensions)> preprocessImageYoloOnnx(
     Uint8List imageData, {
     required bool normalize,
     required int requiredWidth,
@@ -365,13 +366,13 @@ class ImageMlIsolate {
       ),
     );
     final inputs = results['inputs'] as Float32List;
-    final originalSize = Size(
-      results['originalWidth'] as double,
-      results['originalHeight'] as double,
+    final originalSize = Dimensions(
+      width: results['originalWidth'] as int,
+      height: results['originalHeight'] as int,
     );
-    final newSize = Size(
-      results['newWidth'] as double,
-      results['newHeight'] as double,
+    final newSize = Dimensions(
+      width: results['newWidth'] as int,
+      height: results['newHeight'] as int,
     );
     return (inputs, originalSize, newSize);
   }

View File

@@ -17,6 +17,7 @@ import "dart:ui";
 import 'package:flutter/painting.dart' as paint show decodeImageFromList;
 import 'package:ml_linalg/linalg.dart';
 import "package:photos/face/model/box.dart";
+import "package:photos/face/model/dimension.dart";
 import 'package:photos/models/ml/ml_typedefs.dart';
 import 'package:photos/services/machine_learning/face_ml/face_alignment/alignment_result.dart';
 import 'package:photos/services/machine_learning/face_ml/face_alignment/similarity_transform.dart';
@@ -716,7 +717,8 @@ Future<(Num3DInputMatrix, Size, Size)> preprocessImageToMatrix(
   return (imageMatrix, originalSize, newSize);
 }

-Future<(Float32List, Size, Size)> preprocessImageToFloat32ChannelsFirst(
+Future<(Float32List, Dimensions, Dimensions)>
+    preprocessImageToFloat32ChannelsFirst(
   Image image,
   ByteData imgByteData, {
   required int normalization,
@@ -730,7 +732,7 @@ Future<(Float32List, Size, Size)> preprocessImageToFloat32ChannelsFirst(
       : normalization == 1
           ? normalizePixelRange1
           : normalizePixelNoRange;
-  final originalSize = Size(image.width.toDouble(), image.height.toDouble());
+  final originalSize = Dimensions(width: image.width, height: image.height);

   if (image.width == requiredWidth && image.height == requiredHeight) {
     return (
@@ -784,7 +786,7 @@ Future<(Float32List, Size, Size)> preprocessImageToFloat32ChannelsFirst(
   return (
     processedBytes,
     originalSize,
-    Size(scaledWidth.toDouble(), scaledHeight.toDouble())
+    Dimensions(width: scaledWidth, height: scaledHeight)
   );
 }
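`preprocessImageToFloat32ChannelsFirst` returns both the original size and the scaled size it produced. A sketch of how such a pair is typically derived when the aspect ratio is preserved (`scaledToFit` is a hypothetical helper, not the function's actual body):

```dart
import 'dart:math' show min;

// Hypothetical helper: fit `original` inside a requiredWidth x requiredHeight
// model input while preserving aspect ratio, yielding the newSize half of the
// (originalSize, newSize) pair returned above. Assumes the Dimensions class.
Dimensions scaledToFit(Dimensions original, int requiredWidth, int requiredHeight) {
  final scale = min(
    requiredWidth / original.width,
    requiredHeight / original.height,
  );
  return Dimensions(
    width: (original.width * scale).round(),
    height: (original.height * scale).round(),
  );
}
```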