mirror of
https://github.com/ente-io/ente.git
synced 2025-08-08 07:28:26 +00:00
In band signalling
This commit is contained in:
parent
093b3a67cb
commit
d0b1ff5520
@ -45,7 +45,7 @@ import {
|
||||
convertToJPEG,
|
||||
generateImageThumbnail,
|
||||
} from "./services/imageProcessor";
|
||||
import { clipImageEmbedding, clipTextEmbedding } from "./services/ml-clip";
|
||||
import { clipImageEmbedding, clipTextEmbeddingIfAvailable } from "./services/ml-clip";
|
||||
import { detectFaces, faceEmbedding } from "./services/ml-face";
|
||||
import {
|
||||
clearStores,
|
||||
@ -169,8 +169,8 @@ export const attachIPCHandlers = () => {
|
||||
clipImageEmbedding(jpegImageData),
|
||||
);
|
||||
|
||||
ipcMain.handle("clipTextEmbedding", (_, text: string) =>
|
||||
clipTextEmbedding(text),
|
||||
ipcMain.handle("clipTextEmbeddingIfAvailable", (_, text: string) =>
|
||||
clipTextEmbeddingIfAvailable(text),
|
||||
);
|
||||
|
||||
ipcMain.handle("detectFaces", (_, input: Float32Array) =>
|
||||
|
@ -5,86 +5,21 @@
|
||||
*
|
||||
* @see `web/apps/photos/src/services/clip-service.ts` for more details.
|
||||
*/
|
||||
import { existsSync } from "fs";
|
||||
import jpeg from "jpeg-js";
|
||||
import fs from "node:fs/promises";
|
||||
import * as ort from "onnxruntime-node";
|
||||
import Tokenizer from "../../thirdparty/clip-bpe-ts/mod";
|
||||
import { CustomErrors } from "../../types/ipc";
|
||||
import log from "../log";
|
||||
import { writeStream } from "../stream";
|
||||
import { generateTempFilePath } from "../temp";
|
||||
import { deleteTempFile } from "./ffmpeg";
|
||||
import {
|
||||
createInferenceSession,
|
||||
downloadModel,
|
||||
makeCachedInferenceSession,
|
||||
modelSavePath,
|
||||
} from "./ml";
|
||||
import { makeCachedInferenceSession } from "./ml";
|
||||
|
||||
const cachedCLIPImageSession = makeCachedInferenceSession(
|
||||
"clip-image-vit-32-float32.onnx",
|
||||
351468764 /* 335.2 MB */,
|
||||
);
|
||||
|
||||
const cachedCLIPTextSession = makeCachedInferenceSession(
|
||||
"clip-text-vit-32-uint8.onnx",
|
||||
64173509 /* 61.2 MB */,
|
||||
);
|
||||
|
||||
let textModelDownloadInProgress = false;
|
||||
|
||||
/* TODO(MR): use the generic method. Then we can remove the exports for the
|
||||
internal details functions that we use here */
|
||||
const textModelPathDownloadingIfNeeded = async () => {
|
||||
if (textModelDownloadInProgress)
|
||||
throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
|
||||
|
||||
const modelPath = modelSavePath(textModelName);
|
||||
if (!existsSync(modelPath)) {
|
||||
log.info("CLIP text model not found, downloading");
|
||||
textModelDownloadInProgress = true;
|
||||
downloadModel(modelPath, textModelName)
|
||||
.catch((e) => {
|
||||
// log but otherwise ignore
|
||||
log.error("CLIP text model download failed", e);
|
||||
})
|
||||
.finally(() => {
|
||||
textModelDownloadInProgress = false;
|
||||
});
|
||||
throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
|
||||
} else {
|
||||
const localFileSize = (await fs.stat(modelPath)).size;
|
||||
if (localFileSize !== textModelByteSize) {
|
||||
log.error(
|
||||
`CLIP text model size ${localFileSize} does not match the expected size, downloading again`,
|
||||
);
|
||||
textModelDownloadInProgress = true;
|
||||
downloadModel(modelPath, textModelName)
|
||||
.catch((e) => {
|
||||
// log but otherwise ignore
|
||||
log.error("CLIP text model download failed", e);
|
||||
})
|
||||
.finally(() => {
|
||||
textModelDownloadInProgress = false;
|
||||
});
|
||||
throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
|
||||
}
|
||||
}
|
||||
|
||||
return modelPath;
|
||||
};
|
||||
|
||||
let _textSession: any = null;
|
||||
|
||||
const onnxTextSession = async () => {
|
||||
if (!_textSession) {
|
||||
const modelPath = await textModelPathDownloadingIfNeeded();
|
||||
_textSession = await createInferenceSession(modelPath);
|
||||
}
|
||||
return _textSession;
|
||||
};
|
||||
|
||||
export const clipImageEmbedding = async (jpegImageData: Uint8Array) => {
|
||||
const tempFilePath = await generateTempFilePath("");
|
||||
const imageStream = new Response(jpegImageData.buffer).body;
|
||||
@ -195,6 +130,11 @@ const normalizeEmbedding = (embedding: Float32Array) => {
|
||||
return embedding;
|
||||
};
|
||||
|
||||
const cachedCLIPTextSession = makeCachedInferenceSession(
|
||||
"clip-text-vit-32-uint8.onnx",
|
||||
64173509 /* 61.2 MB */,
|
||||
);
|
||||
|
||||
let _tokenizer: Tokenizer = null;
|
||||
const getTokenizer = () => {
|
||||
if (!_tokenizer) {
|
||||
@ -203,14 +143,21 @@ const getTokenizer = () => {
|
||||
return _tokenizer;
|
||||
};
|
||||
|
||||
export const clipTextEmbedding = async (text: string) => {
|
||||
const session = await Promise.race([
|
||||
export const clipTextEmbeddingIfAvailable = async (text: string) => {
|
||||
const sessionOrStatus = await Promise.race([
|
||||
cachedCLIPTextSession(),
|
||||
new Promise<"downloading-model">((resolve) =>
|
||||
setTimeout(() => resolve("downloading-model"), 100),
|
||||
),
|
||||
"downloading-model",
|
||||
]);
|
||||
await onnxTextSession();
|
||||
|
||||
// Don't wait for the download to complete
|
||||
if (typeof sessionOrStatus == "string") {
|
||||
console.log(
|
||||
"Ignoring CLIP text embedding request because model download is pending",
|
||||
);
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const session = sessionOrStatus;
|
||||
const t1 = Date.now();
|
||||
const tokenizer = getTokenizer();
|
||||
const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
|
||||
@ -223,6 +170,6 @@ export const clipTextEmbedding = async (text: string) => {
|
||||
() =>
|
||||
`onnx/clip text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
|
||||
);
|
||||
const textEmbedding = results["output"].data;
|
||||
const textEmbedding = results["output"].data as Float32Array;
|
||||
return normalizeEmbedding(textEmbedding);
|
||||
};
|
||||
|
@ -15,11 +15,6 @@ const cachedFaceDetectionSession = makeCachedInferenceSession(
|
||||
30762872 /* 29.3 MB */,
|
||||
);
|
||||
|
||||
const cachedFaceEmbeddingSession = makeCachedInferenceSession(
|
||||
"mobilefacenet_opset15.onnx",
|
||||
5286998 /* 5 MB */,
|
||||
);
|
||||
|
||||
export const detectFaces = async (input: Float32Array) => {
|
||||
const session = await cachedFaceDetectionSession();
|
||||
const t = Date.now();
|
||||
@ -31,6 +26,11 @@ export const detectFaces = async (input: Float32Array) => {
|
||||
return results["output"].data;
|
||||
};
|
||||
|
||||
const cachedFaceEmbeddingSession = makeCachedInferenceSession(
|
||||
"mobilefacenet_opset15.onnx",
|
||||
5286998 /* 5 MB */,
|
||||
);
|
||||
|
||||
export const faceEmbedding = async (input: Float32Array) => {
|
||||
// Dimension of each face (alias)
|
||||
const mobileFaceNetFaceSize = 112;
|
||||
|
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* @file AI/ML related functionality.
|
||||
* @file AI/ML related functionality, generic layer.
|
||||
*
|
||||
* @see also `ml-clip.ts`, `ml-face.ts`.
|
||||
*
|
||||
@ -92,10 +92,10 @@ const modelPathDownloadingIfNeeded = async (
|
||||
};
|
||||
|
||||
/** Return the path where the given {@link modelName} is meant to be saved */
|
||||
export const modelSavePath = (modelName: string) =>
|
||||
const modelSavePath = (modelName: string) =>
|
||||
path.join(app.getPath("userData"), "models", modelName);
|
||||
|
||||
export const downloadModel = async (saveLocation: string, name: string) => {
|
||||
const downloadModel = async (saveLocation: string, name: string) => {
|
||||
// `mkdir -p` the directory where we want to save the model.
|
||||
const saveDir = path.dirname(saveLocation);
|
||||
await fs.mkdir(saveDir, { recursive: true });
|
||||
@ -112,7 +112,7 @@ export const downloadModel = async (saveLocation: string, name: string) => {
|
||||
/**
|
||||
* Crete an ONNX {@link InferenceSession} with some defaults.
|
||||
*/
|
||||
export const createInferenceSession = async (modelPath: string) => {
|
||||
const createInferenceSession = async (modelPath: string) => {
|
||||
return await ort.InferenceSession.create(modelPath, {
|
||||
// Restrict the number of threads to 1
|
||||
intraOpNumThreads: 1,
|
||||
|
@ -163,8 +163,10 @@ const runFFmpegCmd = (
|
||||
const clipImageEmbedding = (jpegImageData: Uint8Array): Promise<Float32Array> =>
|
||||
ipcRenderer.invoke("clipImageEmbedding", jpegImageData);
|
||||
|
||||
const clipTextEmbedding = (text: string): Promise<Float32Array> =>
|
||||
ipcRenderer.invoke("clipTextEmbedding", text);
|
||||
const clipTextEmbeddingIfAvailable = (
|
||||
text: string,
|
||||
): Promise<Float32Array | undefined> =>
|
||||
ipcRenderer.invoke("clipTextEmbeddingIfAvailable", text);
|
||||
|
||||
const detectFaces = (input: Float32Array): Promise<Float32Array> =>
|
||||
ipcRenderer.invoke("detectFaces", input);
|
||||
@ -263,42 +265,61 @@ const getElectronFilesFromGoogleZip = (
|
||||
const getDirFiles = (dirPath: string): Promise<ElectronFile[]> =>
|
||||
ipcRenderer.invoke("getDirFiles", dirPath);
|
||||
|
||||
//
|
||||
// These objects exposed here will become available to the JS code in our
|
||||
// renderer (the web/ code) as `window.ElectronAPIs.*`
|
||||
//
|
||||
// There are a few related concepts at play here, and it might be worthwhile to
|
||||
// read their (excellent) documentation to get an understanding;
|
||||
//`
|
||||
// - ContextIsolation:
|
||||
// https://www.electronjs.org/docs/latest/tutorial/context-isolation
|
||||
//
|
||||
// - IPC https://www.electronjs.org/docs/latest/tutorial/ipc
|
||||
//
|
||||
// [Note: Transferring large amount of data over IPC]
|
||||
//
|
||||
// Electron's IPC implementation uses the HTML standard Structured Clone
|
||||
// Algorithm to serialize objects passed between processes.
|
||||
// https://www.electronjs.org/docs/latest/tutorial/ipc#object-serialization
|
||||
//
|
||||
// In particular, ArrayBuffer is eligible for structured cloning.
|
||||
// https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Structured_clone_algorithm
|
||||
//
|
||||
// Also, ArrayBuffer is "transferable", which means it is a zero-copy operation
|
||||
// operation when it happens across threads.
|
||||
// https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Transferable_objects
|
||||
//
|
||||
// In our case though, we're not dealing with threads but separate processes. So
|
||||
// the ArrayBuffer will be copied:
|
||||
// > "parameters, errors and return values are **copied** when they're sent over
|
||||
// the bridge".
|
||||
// https://www.electronjs.org/docs/latest/api/context-bridge#methods
|
||||
//
|
||||
// The copy itself is relatively fast, but the problem with transfering large
|
||||
// amounts of data is potentially running out of memory during the copy.
|
||||
//
|
||||
// For an alternative, see [Note: IPC streams].
|
||||
//
|
||||
/**
|
||||
* These objects exposed here will become available to the JS code in our
|
||||
* renderer (the web/ code) as `window.ElectronAPIs.*`
|
||||
*
|
||||
* There are a few related concepts at play here, and it might be worthwhile to
|
||||
* read their (excellent) documentation to get an understanding;
|
||||
*`
|
||||
* - ContextIsolation:
|
||||
* https://www.electronjs.org/docs/latest/tutorial/context-isolation
|
||||
*
|
||||
* - IPC https://www.electronjs.org/docs/latest/tutorial/ipc
|
||||
*
|
||||
* ---
|
||||
*
|
||||
* [Note: Custom errors across Electron/Renderer boundary]
|
||||
*
|
||||
* If we need to identify errors thrown by the main process when invoked from
|
||||
* the renderer process, we can only use the `message` field because:
|
||||
*
|
||||
* > Errors thrown throw `handle` in the main process are not transparent as
|
||||
* > they are serialized and only the `message` property from the original error
|
||||
* > is provided to the renderer process.
|
||||
* >
|
||||
* > - https://www.electronjs.org/docs/latest/tutorial/ipc
|
||||
* >
|
||||
* > Ref: https://github.com/electron/electron/issues/24427
|
||||
*
|
||||
* ---
|
||||
*
|
||||
* [Note: Transferring large amount of data over IPC]
|
||||
*
|
||||
* Electron's IPC implementation uses the HTML standard Structured Clone
|
||||
* Algorithm to serialize objects passed between processes.
|
||||
* https://www.electronjs.org/docs/latest/tutorial/ipc#object-serialization
|
||||
*
|
||||
* In particular, ArrayBuffer is eligible for structured cloning.
|
||||
* https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Structured_clone_algorithm
|
||||
*
|
||||
* Also, ArrayBuffer is "transferable", which means it is a zero-copy operation
|
||||
* operation when it happens across threads.
|
||||
* https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Transferable_objects
|
||||
*
|
||||
* In our case though, we're not dealing with threads but separate processes. So
|
||||
* the ArrayBuffer will be copied:
|
||||
*
|
||||
* > "parameters, errors and return values are **copied** when they're sent over
|
||||
* > the bridge".
|
||||
* >
|
||||
* > https://www.electronjs.org/docs/latest/api/context-bridge#methods
|
||||
*
|
||||
* The copy itself is relatively fast, but the problem with transfering large
|
||||
* amounts of data is potentially running out of memory during the copy.
|
||||
*
|
||||
* For an alternative, see [Note: IPC streams].
|
||||
*/
|
||||
contextBridge.exposeInMainWorld("electron", {
|
||||
// - General
|
||||
|
||||
@ -340,7 +361,7 @@ contextBridge.exposeInMainWorld("electron", {
|
||||
// - ML
|
||||
|
||||
clipImageEmbedding,
|
||||
clipTextEmbedding,
|
||||
clipTextEmbeddingIfAvailable,
|
||||
detectFaces,
|
||||
faceEmbedding,
|
||||
|
||||
|
@ -33,25 +33,10 @@ export interface PendingUploads {
|
||||
|
||||
/**
|
||||
* Errors that have special semantics on the web side.
|
||||
*
|
||||
* [Note: Custom errors across Electron/Renderer boundary]
|
||||
*
|
||||
* We need to use the `message` field to disambiguate between errors thrown by
|
||||
* the main process when invoked from the renderer process. This is because:
|
||||
*
|
||||
* > Errors thrown throw `handle` in the main process are not transparent as
|
||||
* > they are serialized and only the `message` property from the original error
|
||||
* > is provided to the renderer process.
|
||||
* >
|
||||
* > - https://www.electronjs.org/docs/latest/tutorial/ipc
|
||||
* >
|
||||
* > Ref: https://github.com/electron/electron/issues/24427
|
||||
*/
|
||||
export const CustomErrors = {
|
||||
WINDOWS_NATIVE_IMAGE_PROCESSING_NOT_SUPPORTED:
|
||||
"Windows native image processing is not supported",
|
||||
MODEL_DOWNLOAD_PENDING:
|
||||
"Model download pending, skipping clip search request",
|
||||
};
|
||||
|
||||
/**
|
||||
|
@ -184,8 +184,8 @@ class CLIPService {
|
||||
}
|
||||
};
|
||||
|
||||
getTextEmbedding = async (text: string) => {
|
||||
return ensureElectron().clipTextEmbedding(text);
|
||||
getTextEmbeddingIfAvailable = async (text: string) => {
|
||||
return ensureElectron().clipTextEmbeddingIfAvailable(text);
|
||||
};
|
||||
|
||||
private runClipEmbeddingExtraction = async (canceller: AbortController) => {
|
||||
|
@ -1,5 +1,4 @@
|
||||
import log from "@/next/log";
|
||||
import { CustomError } from "@ente/shared/error";
|
||||
import * as chrono from "chrono-node";
|
||||
import { FILE_TYPE } from "constants/file";
|
||||
import { t } from "i18next";
|
||||
@ -287,24 +286,20 @@ async function getLocationSuggestions(searchPhrase: string) {
|
||||
return [...locationTagSuggestions, ...citySearchSuggestions];
|
||||
}
|
||||
|
||||
async function getClipSuggestion(searchPhrase: string): Promise<Suggestion> {
|
||||
try {
|
||||
if (!clipService.isPlatformSupported()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const clipResults = await searchClip(searchPhrase);
|
||||
return {
|
||||
type: SuggestionType.CLIP,
|
||||
value: clipResults,
|
||||
label: searchPhrase,
|
||||
};
|
||||
} catch (e) {
|
||||
if (!e.message?.includes(CustomError.MODEL_DOWNLOAD_PENDING)) {
|
||||
log.error("getClipSuggestion failed", e);
|
||||
}
|
||||
async function getClipSuggestion(
|
||||
searchPhrase: string,
|
||||
): Promise<Suggestion | undefined> {
|
||||
if (!clipService.isPlatformSupported()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const clipResults = await searchClip(searchPhrase);
|
||||
if (!clipResults) return clipResults;
|
||||
return {
|
||||
type: SuggestionType.CLIP,
|
||||
value: clipResults,
|
||||
label: searchPhrase,
|
||||
};
|
||||
}
|
||||
|
||||
function searchCollection(
|
||||
@ -374,9 +369,14 @@ async function searchLocationTag(searchPhrase: string): Promise<LocationTag[]> {
|
||||
return matchedLocationTags;
|
||||
}
|
||||
|
||||
async function searchClip(searchPhrase: string): Promise<ClipSearchScores> {
|
||||
const searchClip = async (
|
||||
searchPhrase: string,
|
||||
): Promise<ClipSearchScores | undefined> => {
|
||||
const textEmbedding =
|
||||
await clipService.getTextEmbeddingIfAvailable(searchPhrase);
|
||||
if (!textEmbedding) return undefined;
|
||||
|
||||
const imageEmbeddings = await getLocalEmbeddings();
|
||||
const textEmbedding = await clipService.getTextEmbedding(searchPhrase);
|
||||
const clipSearchResult = new Map<number, number>(
|
||||
(
|
||||
await Promise.all(
|
||||
@ -394,7 +394,7 @@ async function searchClip(searchPhrase: string): Promise<ClipSearchScores> {
|
||||
);
|
||||
|
||||
return clipSearchResult;
|
||||
}
|
||||
};
|
||||
|
||||
function convertSuggestionToSearchQuery(option: Suggestion): Search {
|
||||
switch (option.type) {
|
||||
|
@ -240,7 +240,18 @@ export interface Electron {
|
||||
clipImageEmbedding: (jpegImageData: Uint8Array) => Promise<Float32Array>;
|
||||
|
||||
/**
|
||||
* Return a CLIP embedding of the given image.
|
||||
* Return a CLIP embedding of the given image if we already have the model
|
||||
* downloaded and prepped. If the model is not available return `undefined`.
|
||||
*
|
||||
* This differs from the other sibling ML functions in that it doesn't wait
|
||||
* for the model download to finish. It does trigger a model download, but
|
||||
* then immediately returns `undefined`. At some future point, when the
|
||||
* model downloaded finishes, calls to this function will start returning
|
||||
* the result we seek.
|
||||
*
|
||||
* The reason for doing it in this asymmetric way is because CLIP text
|
||||
* embeddings are used as part of deducing user initiated search results,
|
||||
* and we don't want to block that interaction on a large network request.
|
||||
*
|
||||
* See: [Note: CLIP based magic search]
|
||||
*
|
||||
@ -248,7 +259,9 @@ export interface Electron {
|
||||
*
|
||||
* @returns A CLIP embedding.
|
||||
*/
|
||||
clipTextEmbedding: (text: string) => Promise<Float32Array>;
|
||||
clipTextEmbeddingIfAvailable: (
|
||||
text: string,
|
||||
) => Promise<Float32Array | undefined>;
|
||||
|
||||
/**
|
||||
* Detect faces in the given image using YOLO.
|
||||
|
@ -84,8 +84,6 @@ export const CustomError = {
|
||||
ServerError: "server error",
|
||||
FILE_NOT_FOUND: "file not found",
|
||||
UNSUPPORTED_PLATFORM: "Unsupported platform",
|
||||
MODEL_DOWNLOAD_PENDING:
|
||||
"Model download pending, skipping clip search request",
|
||||
UPDATE_URL_FILE_ID_MISMATCH: "update url file id mismatch",
|
||||
URL_ALREADY_SET: "url already set",
|
||||
FILE_CONVERSION_FAILED: "file conversion failed",
|
||||
|
Loading…
x
Reference in New Issue
Block a user