In-band signalling

Manav Rathi 2024-04-20 10:10:33 +05:30
parent 093b3a67cb
commit d0b1ff5520
10 changed files with 129 additions and 165 deletions

View File

@@ -45,7 +45,7 @@ import {
     convertToJPEG,
     generateImageThumbnail,
 } from "./services/imageProcessor";
-import { clipImageEmbedding, clipTextEmbedding } from "./services/ml-clip";
+import { clipImageEmbedding, clipTextEmbeddingIfAvailable } from "./services/ml-clip";
 import { detectFaces, faceEmbedding } from "./services/ml-face";
 import {
     clearStores,
@@ -169,8 +169,8 @@ export const attachIPCHandlers = () => {
         clipImageEmbedding(jpegImageData),
     );
 
-    ipcMain.handle("clipTextEmbedding", (_, text: string) =>
-        clipTextEmbedding(text),
+    ipcMain.handle("clipTextEmbeddingIfAvailable", (_, text: string) =>
+        clipTextEmbeddingIfAvailable(text),
     );
 
     ipcMain.handle("detectFaces", (_, input: Float32Array) =>

View File

@@ -5,86 +5,21 @@
  *
  * @see `web/apps/photos/src/services/clip-service.ts` for more details.
  */
-import { existsSync } from "fs";
 import jpeg from "jpeg-js";
 import fs from "node:fs/promises";
 import * as ort from "onnxruntime-node";
 import Tokenizer from "../../thirdparty/clip-bpe-ts/mod";
-import { CustomErrors } from "../../types/ipc";
 import log from "../log";
 import { writeStream } from "../stream";
 import { generateTempFilePath } from "../temp";
 import { deleteTempFile } from "./ffmpeg";
-import {
-    createInferenceSession,
-    downloadModel,
-    makeCachedInferenceSession,
-    modelSavePath,
-} from "./ml";
+import { makeCachedInferenceSession } from "./ml";
 
 const cachedCLIPImageSession = makeCachedInferenceSession(
     "clip-image-vit-32-float32.onnx",
     351468764 /* 335.2 MB */,
 );
 
-const cachedCLIPTextSession = makeCachedInferenceSession(
-    "clip-text-vit-32-uint8.onnx",
-    64173509 /* 61.2 MB */,
-);
-
-let textModelDownloadInProgress = false;
-
-/* TODO(MR): use the generic method. Then we can remove the exports for the
-   internal details functions that we use here */
-const textModelPathDownloadingIfNeeded = async () => {
-    if (textModelDownloadInProgress)
-        throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
-
-    const modelPath = modelSavePath(textModelName);
-    if (!existsSync(modelPath)) {
-        log.info("CLIP text model not found, downloading");
-        textModelDownloadInProgress = true;
-        downloadModel(modelPath, textModelName)
-            .catch((e) => {
-                // log but otherwise ignore
-                log.error("CLIP text model download failed", e);
-            })
-            .finally(() => {
-                textModelDownloadInProgress = false;
-            });
-        throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
-    } else {
-        const localFileSize = (await fs.stat(modelPath)).size;
-        if (localFileSize !== textModelByteSize) {
-            log.error(
-                `CLIP text model size ${localFileSize} does not match the expected size, downloading again`,
-            );
-            textModelDownloadInProgress = true;
-            downloadModel(modelPath, textModelName)
-                .catch((e) => {
-                    // log but otherwise ignore
-                    log.error("CLIP text model download failed", e);
-                })
-                .finally(() => {
-                    textModelDownloadInProgress = false;
-                });
-            throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
-        }
-    }
-
-    return modelPath;
-};
-
-let _textSession: any = null;
-
-const onnxTextSession = async () => {
-    if (!_textSession) {
-        const modelPath = await textModelPathDownloadingIfNeeded();
-        _textSession = await createInferenceSession(modelPath);
-    }
-    return _textSession;
-};
-
 export const clipImageEmbedding = async (jpegImageData: Uint8Array) => {
     const tempFilePath = await generateTempFilePath("");
     const imageStream = new Response(jpegImageData.buffer).body;
@@ -195,6 +130,11 @@ const normalizeEmbedding = (embedding: Float32Array) => {
     return embedding;
 };
 
+const cachedCLIPTextSession = makeCachedInferenceSession(
+    "clip-text-vit-32-uint8.onnx",
+    64173509 /* 61.2 MB */,
+);
+
 let _tokenizer: Tokenizer = null;
 const getTokenizer = () => {
     if (!_tokenizer) {
@@ -203,14 +143,21 @@ const getTokenizer = () => {
     return _tokenizer;
 };
 
-export const clipTextEmbedding = async (text: string) => {
-    const session = await Promise.race([
+export const clipTextEmbeddingIfAvailable = async (text: string) => {
+    const sessionOrStatus = await Promise.race([
         cachedCLIPTextSession(),
-        new Promise<"downloading-model">((resolve) =>
-            setTimeout(() => resolve("downloading-model"), 100),
-        ),
+        "downloading-model",
     ]);
-    await onnxTextSession();
+
+    // Don't wait for the download to complete
+    if (typeof sessionOrStatus == "string") {
+        console.log(
+            "Ignoring CLIP text embedding request because model download is pending",
+        );
+        return undefined;
+    }
+
+    const session = sessionOrStatus;
     const t1 = Date.now();
     const tokenizer = getTokenizer();
     const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
@@ -223,6 +170,6 @@ export const clipTextEmbedding = async (text: string) => {
         () =>
             `onnx/clip text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
     );
-    const textEmbedding = results["output"].data;
+    const textEmbedding = results["output"].data as Float32Array;
     return normalizeEmbedding(textEmbedding);
 };
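The crux of the change is visible in the hunk above: `Promise.race` is given a bare string alongside the memoized session promise. A non-promise entry in a race behaves like an already-resolved promise, so the race settles with the sentinel whenever the session promise is still pending; if the session promise has already been fulfilled, it wins instead, because `Promise.race` attaches to its arguments in order. A minimal standalone sketch of the trick (the helper name here is illustrative, not from the codebase):

```ts
import * as ort from "onnxruntime-node";

// In-band signalling via Promise.race: the bare string acts as an
// already-settled competitor, so a still-pending session promise loses the
// race and we return undefined instead of blocking on the model download.
const sessionIfAvailable = async (
    session: Promise<ort.InferenceSession>,
): Promise<ort.InferenceSession | undefined> => {
    const sessionOrStatus = await Promise.race([
        session,
        "downloading-model" as const,
    ]);
    return typeof sessionOrStatus == "string" ? undefined : sessionOrStatus;
};
```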

View File

@@ -15,11 +15,6 @@ const cachedFaceDetectionSession = makeCachedInferenceSession(
     30762872 /* 29.3 MB */,
 );
 
-const cachedFaceEmbeddingSession = makeCachedInferenceSession(
-    "mobilefacenet_opset15.onnx",
-    5286998 /* 5 MB */,
-);
-
 export const detectFaces = async (input: Float32Array) => {
     const session = await cachedFaceDetectionSession();
     const t = Date.now();
@@ -31,6 +26,11 @@ export const detectFaces = async (input: Float32Array) => {
     return results["output"].data;
 };
 
+const cachedFaceEmbeddingSession = makeCachedInferenceSession(
+    "mobilefacenet_opset15.onnx",
+    5286998 /* 5 MB */,
+);
+
 export const faceEmbedding = async (input: Float32Array) => {
     // Dimension of each face (alias)
     const mobileFaceNetFaceSize = 112;

View File

@@ -1,5 +1,5 @@
 /**
- * @file AI/ML related functionality.
+ * @file AI/ML related functionality, generic layer.
  *
  * @see also `ml-clip.ts`, `ml-face.ts`.
  *
@@ -92,10 +92,10 @@ const modelPathDownloadingIfNeeded = async (
 };
 
 /** Return the path where the given {@link modelName} is meant to be saved */
-export const modelSavePath = (modelName: string) =>
+const modelSavePath = (modelName: string) =>
     path.join(app.getPath("userData"), "models", modelName);
 
-export const downloadModel = async (saveLocation: string, name: string) => {
+const downloadModel = async (saveLocation: string, name: string) => {
     // `mkdir -p` the directory where we want to save the model.
     const saveDir = path.dirname(saveLocation);
     await fs.mkdir(saveDir, { recursive: true });
@@ -112,7 +112,7 @@ export const downloadModel = async (saveLocation: string, name: string) => {
 /**
  * Create an ONNX {@link InferenceSession} with some defaults.
  */
-export const createInferenceSession = async (modelPath: string) => {
+const createInferenceSession = async (modelPath: string) => {
     return await ort.InferenceSession.create(modelPath, {
         // Restrict the number of threads to 1
         intraOpNumThreads: 1,
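Both ml-clip.ts and ml-face.ts lean on `makeCachedInferenceSession` from this generic layer, which the diff references but never shows. For the race in `clipTextEmbeddingIfAvailable` to work, it presumably memoizes a single promise for the download-plus-session-creation work, so the first call kicks off the download and every later call races the very same pending promise. A plausible sketch under that assumption (the parameter list of `modelPathDownloadingIfNeeded` is also assumed):

```ts
// Assumed sketch, not shown in this commit: memoize the promise for
// "download the model if needed, then create an ONNX inference session".
// Returning the same promise on every call is what lets callers race it
// against a sentinel to test availability without awaiting the download.
const makeCachedInferenceSession = (modelName: string, byteSize: number) => {
    let session: Promise<ort.InferenceSession> | undefined;
    return () =>
        (session ??= modelPathDownloadingIfNeeded(modelName, byteSize).then(
            (modelPath) => createInferenceSession(modelPath),
        ));
};
```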

View File

@@ -163,8 +163,10 @@ const runFFmpegCmd = (
 const clipImageEmbedding = (jpegImageData: Uint8Array): Promise<Float32Array> =>
     ipcRenderer.invoke("clipImageEmbedding", jpegImageData);
 
-const clipTextEmbedding = (text: string): Promise<Float32Array> =>
-    ipcRenderer.invoke("clipTextEmbedding", text);
+const clipTextEmbeddingIfAvailable = (
+    text: string,
+): Promise<Float32Array | undefined> =>
+    ipcRenderer.invoke("clipTextEmbeddingIfAvailable", text);
 
 const detectFaces = (input: Float32Array): Promise<Float32Array> =>
     ipcRenderer.invoke("detectFaces", input);
@@ -263,42 +265,61 @@ const getElectronFilesFromGoogleZip = (
 const getDirFiles = (dirPath: string): Promise<ElectronFile[]> =>
     ipcRenderer.invoke("getDirFiles", dirPath);
 
-// These objects exposed here will become available to the JS code in our
-// renderer (the web/ code) as `window.ElectronAPIs.*`
-//
-// There are a few related concepts at play here, and it might be worthwhile to
-// read their (excellent) documentation to get an understanding;
-//
-// - ContextIsolation:
-//   https://www.electronjs.org/docs/latest/tutorial/context-isolation
-//
-// - IPC https://www.electronjs.org/docs/latest/tutorial/ipc
-//
-// [Note: Transferring large amount of data over IPC]
-//
-// Electron's IPC implementation uses the HTML standard Structured Clone
-// Algorithm to serialize objects passed between processes.
-// https://www.electronjs.org/docs/latest/tutorial/ipc#object-serialization
-//
-// In particular, ArrayBuffer is eligible for structured cloning.
-// https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Structured_clone_algorithm
-//
-// Also, ArrayBuffer is "transferable", which means it is a zero-copy operation
-// when it happens across threads.
-// https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Transferable_objects
-//
-// In our case though, we're not dealing with threads but separate processes. So
-// the ArrayBuffer will be copied:
-// > "parameters, errors and return values are **copied** when they're sent over
-//   the bridge".
-// https://www.electronjs.org/docs/latest/api/context-bridge#methods
-//
-// The copy itself is relatively fast, but the problem with transferring large
-// amounts of data is potentially running out of memory during the copy.
-//
-// For an alternative, see [Note: IPC streams].
-//
+/**
+ * These objects exposed here will become available to the JS code in our
+ * renderer (the web/ code) as `window.ElectronAPIs.*`
+ *
+ * There are a few related concepts at play here, and it might be worthwhile to
+ * read their (excellent) documentation to get an understanding;
+ *
+ * - ContextIsolation:
+ *   https://www.electronjs.org/docs/latest/tutorial/context-isolation
+ *
+ * - IPC https://www.electronjs.org/docs/latest/tutorial/ipc
+ *
+ * ---
+ *
+ * [Note: Custom errors across Electron/Renderer boundary]
+ *
+ * If we need to identify errors thrown by the main process when invoked from
+ * the renderer process, we can only use the `message` field because:
+ *
+ * > Errors thrown through `handle` in the main process are not transparent as
+ * > they are serialized and only the `message` property from the original error
+ * > is provided to the renderer process.
+ * >
+ * > - https://www.electronjs.org/docs/latest/tutorial/ipc
+ * >
+ * > Ref: https://github.com/electron/electron/issues/24427
+ *
+ * ---
+ *
+ * [Note: Transferring large amount of data over IPC]
+ *
+ * Electron's IPC implementation uses the HTML standard Structured Clone
+ * Algorithm to serialize objects passed between processes.
+ * https://www.electronjs.org/docs/latest/tutorial/ipc#object-serialization
+ *
+ * In particular, ArrayBuffer is eligible for structured cloning.
+ * https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Structured_clone_algorithm
+ *
+ * Also, ArrayBuffer is "transferable", which means it is a zero-copy operation
+ * when it happens across threads.
+ * https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Transferable_objects
+ *
+ * In our case though, we're not dealing with threads but separate processes. So
+ * the ArrayBuffer will be copied:
+ *
+ * > "parameters, errors and return values are **copied** when they're sent over
+ * > the bridge".
+ * >
+ * > https://www.electronjs.org/docs/latest/api/context-bridge#methods
+ *
+ * The copy itself is relatively fast, but the problem with transferring large
+ * amounts of data is potentially running out of memory during the copy.
+ *
+ * For an alternative, see [Note: IPC streams].
+ */
 contextBridge.exposeInMainWorld("electron", {
     // - General
@@ -340,7 +361,7 @@ contextBridge.exposeInMainWorld("electron", {
     // - ML
 
     clipImageEmbedding,
-    clipTextEmbedding,
+    clipTextEmbeddingIfAvailable,
     detectFaces,
     faceEmbedding,
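The [Note: Custom errors across Electron/Renderer boundary] above is the motivation for the whole commit: since only `message` survives `invoke`/`handle` serialization, signalling "model not ready" out-of-band forces the renderer into brittle string matching, while the in-band `undefined` is enforced by the type checker. A hypothetical before/after from the renderer's point of view (`electron` stands for the exposed `window.electron` bridge):

```ts
// Before (out-of-band): the renderer can only string-match on e.message.
try {
    await electron.clipTextEmbedding("hello");
} catch (e) {
    if (e instanceof Error && e.message.includes("Model download pending")) {
        // expected while the model downloads; ignore
    } else throw e;
}

// After (in-band): absence is part of the return type.
const embedding = await electron.clipTextEmbeddingIfAvailable("hello");
if (!embedding) {
    // model not ready yet; skip CLIP for this search
}
```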

View File

@@ -33,25 +33,10 @@ export interface PendingUploads {
 /**
  * Errors that have special semantics on the web side.
- *
- * [Note: Custom errors across Electron/Renderer boundary]
- *
- * We need to use the `message` field to disambiguate between errors thrown by
- * the main process when invoked from the renderer process. This is because:
- *
- * > Errors thrown through `handle` in the main process are not transparent as
- * > they are serialized and only the `message` property from the original error
- * > is provided to the renderer process.
- * >
- * > - https://www.electronjs.org/docs/latest/tutorial/ipc
- * >
- * > Ref: https://github.com/electron/electron/issues/24427
  */
 export const CustomErrors = {
     WINDOWS_NATIVE_IMAGE_PROCESSING_NOT_SUPPORTED:
         "Windows native image processing is not supported",
-    MODEL_DOWNLOAD_PENDING:
-        "Model download pending, skipping clip search request",
 };
 
 /**

View File

@@ -184,8 +184,8 @@ class CLIPService {
         }
     };
 
-    getTextEmbedding = async (text: string) => {
-        return ensureElectron().clipTextEmbedding(text);
+    getTextEmbeddingIfAvailable = async (text: string) => {
+        return ensureElectron().clipTextEmbeddingIfAvailable(text);
     };
 
     private runClipEmbeddingExtraction = async (canceller: AbortController) => {

View File

@@ -1,5 +1,4 @@
 import log from "@/next/log";
-import { CustomError } from "@ente/shared/error";
 import * as chrono from "chrono-node";
 import { FILE_TYPE } from "constants/file";
 import { t } from "i18next";
@@ -287,24 +286,20 @@ async function getLocationSuggestions(searchPhrase: string) {
     return [...locationTagSuggestions, ...citySearchSuggestions];
 }
 
-async function getClipSuggestion(searchPhrase: string): Promise<Suggestion> {
-    try {
-        if (!clipService.isPlatformSupported()) {
-            return null;
-        }
-
-        const clipResults = await searchClip(searchPhrase);
-        return {
-            type: SuggestionType.CLIP,
-            value: clipResults,
-            label: searchPhrase,
-        };
-    } catch (e) {
-        if (!e.message?.includes(CustomError.MODEL_DOWNLOAD_PENDING)) {
-            log.error("getClipSuggestion failed", e);
-        }
+async function getClipSuggestion(
+    searchPhrase: string,
+): Promise<Suggestion | undefined> {
+    if (!clipService.isPlatformSupported()) {
         return null;
     }
+
+    const clipResults = await searchClip(searchPhrase);
+    if (!clipResults) return clipResults;
+    return {
+        type: SuggestionType.CLIP,
+        value: clipResults,
+        label: searchPhrase,
+    };
 }
 
 function searchCollection(
@@ -374,9 +369,14 @@ async function searchLocationTag(searchPhrase: string): Promise<LocationTag[]> {
     return matchedLocationTags;
 }
 
-async function searchClip(searchPhrase: string): Promise<ClipSearchScores> {
+const searchClip = async (
+    searchPhrase: string,
+): Promise<ClipSearchScores | undefined> => {
+    const textEmbedding =
+        await clipService.getTextEmbeddingIfAvailable(searchPhrase);
+    if (!textEmbedding) return undefined;
+
     const imageEmbeddings = await getLocalEmbeddings();
-    const textEmbedding = await clipService.getTextEmbedding(searchPhrase);
     const clipSearchResult = new Map<number, number>(
         (
             await Promise.all(
@@ -394,7 +394,7 @@ async function searchClip(searchPhrase: string): Promise<ClipSearchScores> {
     );
 
     return clipSearchResult;
-}
+};
 
 function convertSuggestionToSearchQuery(option: Suggestion): Search {
     switch (option.type) {
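With this change the "not available" state flows up the call chain as ordinary values rather than exceptions: `getTextEmbeddingIfAvailable` yields `undefined`, `searchClip` forwards it (note that `if (!clipResults) return clipResults;` passes the `undefined` through), and `getClipSuggestion` simply produces no suggestion. A hypothetical caller:

```ts
// Hypothetical usage: no try/catch needed, since "model still downloading"
// is now an ordinary value rather than an error to pattern match.
const suggestions: Suggestion[] = [];
const clipSuggestion = await getClipSuggestion("sunset over the sea");
if (clipSuggestion) suggestions.push(clipSuggestion);
// else: CLIP model not yet available (or platform unsupported); the other
// suggestion providers contribute as usual.
```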

View File

@@ -240,7 +240,18 @@ export interface Electron {
     clipImageEmbedding: (jpegImageData: Uint8Array) => Promise<Float32Array>;
 
     /**
-     * Return a CLIP embedding of the given image.
+     * Return a CLIP embedding of the given text if we already have the model
+     * downloaded and prepped. If the model is not available return `undefined`.
+     *
+     * This differs from the other sibling ML functions in that it doesn't wait
+     * for the model download to finish. It does trigger a model download, but
+     * then immediately returns `undefined`. At some future point, when the
+     * model download finishes, calls to this function will start returning
+     * the result we seek.
+     *
+     * The reason for doing it in this asymmetric way is because CLIP text
+     * embeddings are used as part of deducing user initiated search results,
+     * and we don't want to block that interaction on a large network request.
      *
      * See: [Note: CLIP based magic search]
      *
@@ -248,7 +259,9 @@
      *
      * @returns A CLIP embedding.
      */
-    clipTextEmbedding: (text: string) => Promise<Float32Array>;
+    clipTextEmbeddingIfAvailable: (
+        text: string,
+    ) => Promise<Float32Array | undefined>;
 
     /**
      * Detect faces in the given image using YOLO.
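A sketch of the asymmetric behavior this doc comment describes (hypothetical call site; timings depend on the network):

```ts
// First search: triggers the (~61 MB) text model download in the
// background and immediately returns undefined.
const first = await electron.clipTextEmbeddingIfAvailable("red car");
console.log(first); // undefined while the download is pending

// A later search, after the download has completed, gets a real result.
const later = await electron.clipTextEmbeddingIfAvailable("red car");
console.log(later); // Float32Array (normalized CLIP text embedding)
```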

View File

@@ -84,8 +84,6 @@ export const CustomError = {
     ServerError: "server error",
     FILE_NOT_FOUND: "file not found",
     UNSUPPORTED_PLATFORM: "Unsupported platform",
-    MODEL_DOWNLOAD_PENDING:
-        "Model download pending, skipping clip search request",
     UPDATE_URL_FILE_ID_MISMATCH: "update url file id mismatch",
     URL_ALREADY_SET: "url already set",
     FILE_CONVERSION_FAILED: "file conversion failed",