[mob][photos] Don't pause ML when in ML settings page

This commit is contained in:
laurenspriem
2024-06-04 11:33:13 +05:30
parent 9db9a18e3e
commit 4133764cb0
2 changed files with 64 additions and 25 deletions

View File

@@ -24,6 +24,7 @@ class MachineLearningController {
bool _isDeviceHealthy = true;
bool _isUserInteracting = true;
bool _canRunML = false;
bool mlInteractionOverride = false;
late Timer _userInteractionTimer;
bool get isDeviceHealthy => _isDeviceHealthy;
@@ -61,13 +62,23 @@ class MachineLearningController {
_resetTimer();
}
bool _canRunGivenUserInteraction() {
return (Platform.isIOS ? true : !_isUserInteracting) ||
mlInteractionOverride;
}
void forceOverrideML({required bool turnOn}) {
_logger.info("Forcing to turn on ML: $turnOn");
mlInteractionOverride = turnOn;
_fireControlEvent();
}
void _fireControlEvent() {
final shouldRunML =
_isDeviceHealthy && (Platform.isAndroid ? !_isUserInteracting : true);
final shouldRunML = _isDeviceHealthy && _canRunGivenUserInteraction();
if (shouldRunML != _canRunML) {
_canRunML = shouldRunML;
_logger.info(
"Firing event with $shouldRunML, device health: $_isDeviceHealthy and user interaction: $_isUserInteracting",
"Firing event: $shouldRunML (device health: $_isDeviceHealthy, user interaction: $_isUserInteracting, mlInteractionOverride: $mlInteractionOverride)",
);
Bus.instance.fire(MachineLearningControlEvent(shouldRunML));
}

View File

@@ -36,24 +36,29 @@ class MachineLearningSettingsPage extends StatefulWidget {
const MachineLearningSettingsPage({super.key});
@override
State<MachineLearningSettingsPage> createState() => _MachineLearningSettingsPageState();
State<MachineLearningSettingsPage> createState() =>
_MachineLearningSettingsPageState();
}
class _MachineLearningSettingsPageState extends State<MachineLearningSettingsPage> {
class _MachineLearningSettingsPageState
extends State<MachineLearningSettingsPage> {
late InitializationState _state;
final EnteWakeLock _wakeLock = EnteWakeLock();
late StreamSubscription<MLFrameworkInitializationUpdateEvent> _eventSubscription;
late StreamSubscription<MLFrameworkInitializationUpdateEvent>
_eventSubscription;
@override
void initState() {
super.initState();
_eventSubscription = Bus.instance.on<MLFrameworkInitializationUpdateEvent>().listen((event) {
_eventSubscription =
Bus.instance.on<MLFrameworkInitializationUpdateEvent>().listen((event) {
_fetchState();
setState(() {});
});
_fetchState();
_wakeLock.enable();
MachineLearningController.instance.forceOverrideML(turnOn: true);
}
void _fetchState() {
@@ -65,6 +70,7 @@ class _MachineLearningSettingsPageState extends State<MachineLearningSettingsPag
super.dispose();
_eventSubscription.cancel();
_wakeLock.disable();
MachineLearningController.instance.forceOverrideML(turnOn: false);
}
@override
@@ -118,7 +124,9 @@ class _MachineLearningSettingsPageState extends State<MachineLearningSettingsPag
children: [
_getMagicSearchSettings(context),
const SizedBox(height: 12),
facesFlag ? _getFacesSearchSettings(context) : const SizedBox.shrink(),
facesFlag
? _getFacesSearchSettings(context)
: const SizedBox.shrink(),
],
),
),
@@ -150,7 +158,8 @@ class _MachineLearningSettingsPageState extends State<MachineLearningSettingsPag
);
if (LocalSettings.instance.hasEnabledMagicSearch()) {
unawaited(
SemanticSearchService.instance.init(shouldSyncImmediately: true),
SemanticSearchService.instance
.init(shouldSyncImmediately: true),
);
} else {
await SemanticSearchService.instance.clearQueue();
@@ -211,7 +220,8 @@ class _MachineLearningSettingsPageState extends State<MachineLearningSettingsPag
trailingWidget: ToggleSwitchWidget(
value: () => LocalSettings.instance.isFaceIndexingEnabled,
onChanged: () async {
final isEnabled = await LocalSettings.instance.toggleFaceIndexing();
final isEnabled =
await LocalSettings.instance.toggleFaceIndexing();
if (isEnabled) {
unawaited(FaceMlService.instance.ensureInitialized());
} else {
@@ -229,7 +239,9 @@ class _MachineLearningSettingsPageState extends State<MachineLearningSettingsPag
const SizedBox(
height: 12,
),
hasEnabled ? const FaceRecognitionStatusWidget() : const SizedBox.shrink(),
hasEnabled
? const FaceRecognitionStatusWidget()
: const SizedBox.shrink(),
],
);
}
@@ -252,7 +264,8 @@ class _ModelLoadingStateState extends State<ModelLoadingState> {
final Map<String, (int, int)> _progressMap = {};
@override
void initState() {
_progressStream = RemoteAssetsService.instance.progressStream.listen((event) {
_progressStream =
RemoteAssetsService.instance.progressStream.listen((event) {
final String url = event.$1;
String title = "";
if (url.contains("clip-image")) {
@@ -330,17 +343,20 @@ class MagicSearchIndexStatsWidget extends StatefulWidget {
});
@override
State<MagicSearchIndexStatsWidget> createState() => _MagicSearchIndexStatsWidgetState();
State<MagicSearchIndexStatsWidget> createState() =>
_MagicSearchIndexStatsWidgetState();
}
class _MagicSearchIndexStatsWidgetState extends State<MagicSearchIndexStatsWidget> {
class _MagicSearchIndexStatsWidgetState
extends State<MagicSearchIndexStatsWidget> {
IndexStatus? _status;
late StreamSubscription<EmbeddingCacheUpdatedEvent> _eventSubscription;
@override
void initState() {
super.initState();
_eventSubscription = Bus.instance.on<EmbeddingCacheUpdatedEvent>().listen((event) {
_eventSubscription =
Bus.instance.on<EmbeddingCacheUpdatedEvent>().listen((event) {
_fetchIndexStatus();
});
_fetchIndexStatus();
@@ -416,10 +432,12 @@ class FaceRecognitionStatusWidget extends StatefulWidget {
});
@override
State<FaceRecognitionStatusWidget> createState() => FaceRecognitionStatusWidgetState();
State<FaceRecognitionStatusWidget> createState() =>
FaceRecognitionStatusWidgetState();
}
class FaceRecognitionStatusWidgetState extends State<FaceRecognitionStatusWidget> {
class FaceRecognitionStatusWidgetState
extends State<FaceRecognitionStatusWidget> {
Timer? _timer;
@override
void initState() {
@@ -433,15 +451,22 @@ class FaceRecognitionStatusWidgetState extends State<FaceRecognitionStatusWidget
Future<(int, int, double, bool)> getIndexStatus() async {
try {
final indexedFiles =
await FaceMLDataDB.instance.getIndexedFileCount(minimumMlVersion: faceMlVersion);
final indexedFiles = await FaceMLDataDB.instance
.getIndexedFileCount(minimumMlVersion: faceMlVersion);
final indexableFiles = (await getIndexableFileIDs()).length;
final showIndexedFiles = min(indexedFiles, indexableFiles);
final pendingFiles = max(indexableFiles - indexedFiles, 0);
final clusteringDoneRatio = await FaceMLDataDB.instance.getClusteredToIndexableFilesRatio();
final bool deviceIsHealthy = MachineLearningController.instance.isDeviceHealthy;
final clusteringDoneRatio =
await FaceMLDataDB.instance.getClusteredToIndexableFilesRatio();
final bool deviceIsHealthy =
MachineLearningController.instance.isDeviceHealthy;
return (showIndexedFiles, pendingFiles, clusteringDoneRatio, deviceIsHealthy);
return (
showIndexedFiles,
pendingFiles,
clusteringDoneRatio,
deviceIsHealthy
);
} catch (e, s) {
_logger.severe('Error getting face recognition status', e, s);
rethrow;
@@ -471,10 +496,12 @@ class FaceRecognitionStatusWidgetState extends State<FaceRecognitionStatusWidget
final int indexedFiles = snapshot.data!.$1;
final int pendingFiles = snapshot.data!.$2;
final double clusteringDoneRatio = snapshot.data!.$3;
final double clusteringPercentage = (clusteringDoneRatio * 100).clamp(0, 100);
final double clusteringPercentage =
(clusteringDoneRatio * 100).clamp(0, 100);
final bool isDeviceHealthy = snapshot.data!.$4;
if (!isDeviceHealthy && (pendingFiles > 0 || clusteringPercentage < 99)) {
if (!isDeviceHealthy &&
(pendingFiles > 0 || clusteringPercentage < 99)) {
return MenuSectionDescriptionWidget(
content: S.of(context).indexingIsPaused,
);
@@ -524,7 +551,8 @@ class FaceRecognitionStatusWidgetState extends State<FaceRecognitionStatusWidget
key: ValueKey(
FaceMlService.instance.showClusteringIsHappening
? "currently running"
: "clustering_progress_" + clusteringPercentage.toStringAsFixed(0),
: "clustering_progress_" +
clusteringPercentage.toStringAsFixed(0),
),
),
],