[mob][photos] Don't pause ML when in ML settings page

This commit is contained in:
laurenspriem
2024-06-04 11:33:13 +05:30
parent 9db9a18e3e
commit 4133764cb0
2 changed files with 64 additions and 25 deletions

View File

@@ -24,6 +24,7 @@ class MachineLearningController {
bool _isDeviceHealthy = true; bool _isDeviceHealthy = true;
bool _isUserInteracting = true; bool _isUserInteracting = true;
bool _canRunML = false; bool _canRunML = false;
bool mlInteractionOverride = false;
late Timer _userInteractionTimer; late Timer _userInteractionTimer;
bool get isDeviceHealthy => _isDeviceHealthy; bool get isDeviceHealthy => _isDeviceHealthy;
@@ -61,13 +62,23 @@ class MachineLearningController {
_resetTimer(); _resetTimer();
} }
bool _canRunGivenUserInteraction() {
return (Platform.isIOS ? true : !_isUserInteracting) ||
mlInteractionOverride;
}
void forceOverrideML({required bool turnOn}) {
_logger.info("Forcing to turn on ML: $turnOn");
mlInteractionOverride = turnOn;
_fireControlEvent();
}
void _fireControlEvent() { void _fireControlEvent() {
final shouldRunML = final shouldRunML = _isDeviceHealthy && _canRunGivenUserInteraction();
_isDeviceHealthy && (Platform.isAndroid ? !_isUserInteracting : true);
if (shouldRunML != _canRunML) { if (shouldRunML != _canRunML) {
_canRunML = shouldRunML; _canRunML = shouldRunML;
_logger.info( _logger.info(
"Firing event with $shouldRunML, device health: $_isDeviceHealthy and user interaction: $_isUserInteracting", "Firing event: $shouldRunML (device health: $_isDeviceHealthy, user interaction: $_isUserInteracting, mlInteractionOverride: $mlInteractionOverride)",
); );
Bus.instance.fire(MachineLearningControlEvent(shouldRunML)); Bus.instance.fire(MachineLearningControlEvent(shouldRunML));
} }

View File

@@ -36,24 +36,29 @@ class MachineLearningSettingsPage extends StatefulWidget {
const MachineLearningSettingsPage({super.key}); const MachineLearningSettingsPage({super.key});
@override @override
State<MachineLearningSettingsPage> createState() => _MachineLearningSettingsPageState(); State<MachineLearningSettingsPage> createState() =>
_MachineLearningSettingsPageState();
} }
class _MachineLearningSettingsPageState extends State<MachineLearningSettingsPage> { class _MachineLearningSettingsPageState
extends State<MachineLearningSettingsPage> {
late InitializationState _state; late InitializationState _state;
final EnteWakeLock _wakeLock = EnteWakeLock(); final EnteWakeLock _wakeLock = EnteWakeLock();
late StreamSubscription<MLFrameworkInitializationUpdateEvent> _eventSubscription; late StreamSubscription<MLFrameworkInitializationUpdateEvent>
_eventSubscription;
@override @override
void initState() { void initState() {
super.initState(); super.initState();
_eventSubscription = Bus.instance.on<MLFrameworkInitializationUpdateEvent>().listen((event) { _eventSubscription =
Bus.instance.on<MLFrameworkInitializationUpdateEvent>().listen((event) {
_fetchState(); _fetchState();
setState(() {}); setState(() {});
}); });
_fetchState(); _fetchState();
_wakeLock.enable(); _wakeLock.enable();
MachineLearningController.instance.forceOverrideML(turnOn: true);
} }
void _fetchState() { void _fetchState() {
@@ -65,6 +70,7 @@ class _MachineLearningSettingsPageState extends State<MachineLearningSettingsPag
super.dispose(); super.dispose();
_eventSubscription.cancel(); _eventSubscription.cancel();
_wakeLock.disable(); _wakeLock.disable();
MachineLearningController.instance.forceOverrideML(turnOn: false);
} }
@override @override
@@ -118,7 +124,9 @@ class _MachineLearningSettingsPageState extends State<MachineLearningSettingsPag
children: [ children: [
_getMagicSearchSettings(context), _getMagicSearchSettings(context),
const SizedBox(height: 12), const SizedBox(height: 12),
facesFlag ? _getFacesSearchSettings(context) : const SizedBox.shrink(), facesFlag
? _getFacesSearchSettings(context)
: const SizedBox.shrink(),
], ],
), ),
), ),
@@ -150,7 +158,8 @@ class _MachineLearningSettingsPageState extends State<MachineLearningSettingsPag
); );
if (LocalSettings.instance.hasEnabledMagicSearch()) { if (LocalSettings.instance.hasEnabledMagicSearch()) {
unawaited( unawaited(
SemanticSearchService.instance.init(shouldSyncImmediately: true), SemanticSearchService.instance
.init(shouldSyncImmediately: true),
); );
} else { } else {
await SemanticSearchService.instance.clearQueue(); await SemanticSearchService.instance.clearQueue();
@@ -211,7 +220,8 @@ class _MachineLearningSettingsPageState extends State<MachineLearningSettingsPag
trailingWidget: ToggleSwitchWidget( trailingWidget: ToggleSwitchWidget(
value: () => LocalSettings.instance.isFaceIndexingEnabled, value: () => LocalSettings.instance.isFaceIndexingEnabled,
onChanged: () async { onChanged: () async {
final isEnabled = await LocalSettings.instance.toggleFaceIndexing(); final isEnabled =
await LocalSettings.instance.toggleFaceIndexing();
if (isEnabled) { if (isEnabled) {
unawaited(FaceMlService.instance.ensureInitialized()); unawaited(FaceMlService.instance.ensureInitialized());
} else { } else {
@@ -229,7 +239,9 @@ class _MachineLearningSettingsPageState extends State<MachineLearningSettingsPag
const SizedBox( const SizedBox(
height: 12, height: 12,
), ),
hasEnabled ? const FaceRecognitionStatusWidget() : const SizedBox.shrink(), hasEnabled
? const FaceRecognitionStatusWidget()
: const SizedBox.shrink(),
], ],
); );
} }
@@ -252,7 +264,8 @@ class _ModelLoadingStateState extends State<ModelLoadingState> {
final Map<String, (int, int)> _progressMap = {}; final Map<String, (int, int)> _progressMap = {};
@override @override
void initState() { void initState() {
_progressStream = RemoteAssetsService.instance.progressStream.listen((event) { _progressStream =
RemoteAssetsService.instance.progressStream.listen((event) {
final String url = event.$1; final String url = event.$1;
String title = ""; String title = "";
if (url.contains("clip-image")) { if (url.contains("clip-image")) {
@@ -330,17 +343,20 @@ class MagicSearchIndexStatsWidget extends StatefulWidget {
}); });
@override @override
State<MagicSearchIndexStatsWidget> createState() => _MagicSearchIndexStatsWidgetState(); State<MagicSearchIndexStatsWidget> createState() =>
_MagicSearchIndexStatsWidgetState();
} }
class _MagicSearchIndexStatsWidgetState extends State<MagicSearchIndexStatsWidget> { class _MagicSearchIndexStatsWidgetState
extends State<MagicSearchIndexStatsWidget> {
IndexStatus? _status; IndexStatus? _status;
late StreamSubscription<EmbeddingCacheUpdatedEvent> _eventSubscription; late StreamSubscription<EmbeddingCacheUpdatedEvent> _eventSubscription;
@override @override
void initState() { void initState() {
super.initState(); super.initState();
_eventSubscription = Bus.instance.on<EmbeddingCacheUpdatedEvent>().listen((event) { _eventSubscription =
Bus.instance.on<EmbeddingCacheUpdatedEvent>().listen((event) {
_fetchIndexStatus(); _fetchIndexStatus();
}); });
_fetchIndexStatus(); _fetchIndexStatus();
@@ -416,10 +432,12 @@ class FaceRecognitionStatusWidget extends StatefulWidget {
}); });
@override @override
State<FaceRecognitionStatusWidget> createState() => FaceRecognitionStatusWidgetState(); State<FaceRecognitionStatusWidget> createState() =>
FaceRecognitionStatusWidgetState();
} }
class FaceRecognitionStatusWidgetState extends State<FaceRecognitionStatusWidget> { class FaceRecognitionStatusWidgetState
extends State<FaceRecognitionStatusWidget> {
Timer? _timer; Timer? _timer;
@override @override
void initState() { void initState() {
@@ -433,15 +451,22 @@ class FaceRecognitionStatusWidgetState extends State<FaceRecognitionStatusWidget
Future<(int, int, double, bool)> getIndexStatus() async { Future<(int, int, double, bool)> getIndexStatus() async {
try { try {
final indexedFiles = final indexedFiles = await FaceMLDataDB.instance
await FaceMLDataDB.instance.getIndexedFileCount(minimumMlVersion: faceMlVersion); .getIndexedFileCount(minimumMlVersion: faceMlVersion);
final indexableFiles = (await getIndexableFileIDs()).length; final indexableFiles = (await getIndexableFileIDs()).length;
final showIndexedFiles = min(indexedFiles, indexableFiles); final showIndexedFiles = min(indexedFiles, indexableFiles);
final pendingFiles = max(indexableFiles - indexedFiles, 0); final pendingFiles = max(indexableFiles - indexedFiles, 0);
final clusteringDoneRatio = await FaceMLDataDB.instance.getClusteredToIndexableFilesRatio(); final clusteringDoneRatio =
final bool deviceIsHealthy = MachineLearningController.instance.isDeviceHealthy; await FaceMLDataDB.instance.getClusteredToIndexableFilesRatio();
final bool deviceIsHealthy =
MachineLearningController.instance.isDeviceHealthy;
return (showIndexedFiles, pendingFiles, clusteringDoneRatio, deviceIsHealthy); return (
showIndexedFiles,
pendingFiles,
clusteringDoneRatio,
deviceIsHealthy
);
} catch (e, s) { } catch (e, s) {
_logger.severe('Error getting face recognition status', e, s); _logger.severe('Error getting face recognition status', e, s);
rethrow; rethrow;
@@ -471,10 +496,12 @@ class FaceRecognitionStatusWidgetState extends State<FaceRecognitionStatusWidget
final int indexedFiles = snapshot.data!.$1; final int indexedFiles = snapshot.data!.$1;
final int pendingFiles = snapshot.data!.$2; final int pendingFiles = snapshot.data!.$2;
final double clusteringDoneRatio = snapshot.data!.$3; final double clusteringDoneRatio = snapshot.data!.$3;
final double clusteringPercentage = (clusteringDoneRatio * 100).clamp(0, 100); final double clusteringPercentage =
(clusteringDoneRatio * 100).clamp(0, 100);
final bool isDeviceHealthy = snapshot.data!.$4; final bool isDeviceHealthy = snapshot.data!.$4;
if (!isDeviceHealthy && (pendingFiles > 0 || clusteringPercentage < 99)) { if (!isDeviceHealthy &&
(pendingFiles > 0 || clusteringPercentage < 99)) {
return MenuSectionDescriptionWidget( return MenuSectionDescriptionWidget(
content: S.of(context).indexingIsPaused, content: S.of(context).indexingIsPaused,
); );
@@ -524,7 +551,8 @@ class FaceRecognitionStatusWidgetState extends State<FaceRecognitionStatusWidget
key: ValueKey( key: ValueKey(
FaceMlService.instance.showClusteringIsHappening FaceMlService.instance.showClusteringIsHappening
? "currently running" ? "currently running"
: "clustering_progress_" + clusteringPercentage.toStringAsFixed(0), : "clustering_progress_" +
clusteringPercentage.toStringAsFixed(0),
), ),
), ),
], ],