In-band signalling

Manav Rathi 2024-04-20 10:10:33 +05:30
parent 093b3a67cb
commit d0b1ff5520
10 changed files with 129 additions and 165 deletions

View File

@@ -45,7 +45,7 @@ import {
convertToJPEG,
generateImageThumbnail,
} from "./services/imageProcessor";
import { clipImageEmbedding, clipTextEmbedding } from "./services/ml-clip";
import { clipImageEmbedding, clipTextEmbeddingIfAvailable } from "./services/ml-clip";
import { detectFaces, faceEmbedding } from "./services/ml-face";
import {
clearStores,
@@ -169,8 +169,8 @@ export const attachIPCHandlers = () => {
clipImageEmbedding(jpegImageData),
);
ipcMain.handle("clipTextEmbedding", (_, text: string) =>
clipTextEmbedding(text),
ipcMain.handle("clipTextEmbeddingIfAvailable", (_, text: string) =>
clipTextEmbeddingIfAvailable(text),
);
ipcMain.handle("detectFaces", (_, input: Float32Array) =>
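The rename moves the "model not ready" signal in-band: it becomes part of the return value instead of a thrown error. A minimal sketch (channel names hypothetical, not from this commit) of why that suits Electron IPC: `undefined` survives structured-clone serialization intact, while a thrown error reaches the renderer with only its `message`:

import { ipcMain } from "electron";

// Out-of-band: the renderer receives a generic Error; only the original
// message (wrapped in Electron's own prefix) survives serialization.
ipcMain.handle("embedOrThrow", () => {
    throw new Error("model-download-pending");
});

// In-band: undefined is structured-cloneable, so the renderer can test the
// return value directly, with no string matching involved.
ipcMain.handle("embedIfAvailable", (): Float32Array | undefined => undefined);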

View File

@@ -5,86 +5,21 @@
*
* @see `web/apps/photos/src/services/clip-service.ts` for more details.
*/
import { existsSync } from "fs";
import jpeg from "jpeg-js";
import fs from "node:fs/promises";
import * as ort from "onnxruntime-node";
import Tokenizer from "../../thirdparty/clip-bpe-ts/mod";
import { CustomErrors } from "../../types/ipc";
import log from "../log";
import { writeStream } from "../stream";
import { generateTempFilePath } from "../temp";
import { deleteTempFile } from "./ffmpeg";
import {
createInferenceSession,
downloadModel,
makeCachedInferenceSession,
modelSavePath,
} from "./ml";
import { makeCachedInferenceSession } from "./ml";
const cachedCLIPImageSession = makeCachedInferenceSession(
"clip-image-vit-32-float32.onnx",
351468764 /* 335.2 MB */,
);
const cachedCLIPTextSession = makeCachedInferenceSession(
"clip-text-vit-32-uint8.onnx",
64173509 /* 61.2 MB */,
);
let textModelDownloadInProgress = false;
/* TODO(MR): use the generic method. Then we can remove the exports for the
internal details functions that we use here */
const textModelPathDownloadingIfNeeded = async () => {
if (textModelDownloadInProgress)
throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
const modelPath = modelSavePath(textModelName);
if (!existsSync(modelPath)) {
log.info("CLIP text model not found, downloading");
textModelDownloadInProgress = true;
downloadModel(modelPath, textModelName)
.catch((e) => {
// log but otherwise ignore
log.error("CLIP text model download failed", e);
})
.finally(() => {
textModelDownloadInProgress = false;
});
throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
} else {
const localFileSize = (await fs.stat(modelPath)).size;
if (localFileSize !== textModelByteSize) {
log.error(
`CLIP text model size ${localFileSize} does not match the expected size, downloading again`,
);
textModelDownloadInProgress = true;
downloadModel(modelPath, textModelName)
.catch((e) => {
// log but otherwise ignore
log.error("CLIP text model download failed", e);
})
.finally(() => {
textModelDownloadInProgress = false;
});
throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
}
}
return modelPath;
};
let _textSession: any = null;
const onnxTextSession = async () => {
if (!_textSession) {
const modelPath = await textModelPathDownloadingIfNeeded();
_textSession = await createInferenceSession(modelPath);
}
return _textSession;
};
export const clipImageEmbedding = async (jpegImageData: Uint8Array) => {
const tempFilePath = await generateTempFilePath("");
const imageStream = new Response(jpegImageData.buffer).body;
@@ -195,6 +130,11 @@ const normalizeEmbedding = (embedding: Float32Array) => {
return embedding;
};
const cachedCLIPTextSession = makeCachedInferenceSession(
"clip-text-vit-32-uint8.onnx",
64173509 /* 61.2 MB */,
);
let _tokenizer: Tokenizer = null;
const getTokenizer = () => {
if (!_tokenizer) {
@@ -203,14 +143,21 @@ const getTokenizer = () => {
return _tokenizer;
};
export const clipTextEmbedding = async (text: string) => {
const session = await Promise.race([
export const clipTextEmbeddingIfAvailable = async (text: string) => {
const sessionOrStatus = await Promise.race([
cachedCLIPTextSession(),
new Promise<"downloading-model">((resolve) =>
setTimeout(() => resolve("downloading-model"), 100),
),
"downloading-model",
]);
await onnxTextSession();
// Don't wait for the download to complete
if (typeof sessionOrStatus == "string") {
console.log(
"Ignoring CLIP text embedding request because model download is pending",
);
return undefined;
}
const session = sessionOrStatus;
const t1 = Date.now();
const tokenizer = getTokenizer();
const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
@@ -223,6 +170,6 @@ export const clipTextEmbedding = async (text: string) => {
() =>
`onnx/clip text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
);
const textEmbedding = results["output"].data;
const textEmbedding = results["output"].data as Float32Array;
return normalizeEmbedding(textEmbedding);
};
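The Promise.race above is the crux of the commit: a plain value among the contenders behaves like an already-settled promise, so it wins the race unless cachedCLIPTextSession() is itself already fulfilled (contenders are visited in order, and the settlement callback of an already-fulfilled promise is queued first). A generic sketch of the pattern, with a hypothetical helper name; it assumes the fulfilled value can never equal the sentinel, which the code above guarantees by checking typeof:

const fulfilledValueOrUndefined = async <T>(
    promise: Promise<T>,
): Promise<T | undefined> => {
    const sentinel = "pending";
    // Promise.race settles with the first contender to settle: an already
    // fulfilled promise wins, a still-pending one loses to the sentinel.
    const result = await Promise.race([promise, sentinel]);
    return result === sentinel ? undefined : (result as T);
};

// e.g. const session = await fulfilledValueOrUndefined(cachedCLIPTextSession());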

View File

@@ -15,11 +15,6 @@ const cachedFaceDetectionSession = makeCachedInferenceSession(
30762872 /* 29.3 MB */,
);
const cachedFaceEmbeddingSession = makeCachedInferenceSession(
"mobilefacenet_opset15.onnx",
5286998 /* 5 MB */,
);
export const detectFaces = async (input: Float32Array) => {
const session = await cachedFaceDetectionSession();
const t = Date.now();
@@ -31,6 +26,11 @@ export const detectFaces = async (input: Float32Array) => {
return results["output"].data;
};
const cachedFaceEmbeddingSession = makeCachedInferenceSession(
"mobilefacenet_opset15.onnx",
5286998 /* 5 MB */,
);
export const faceEmbedding = async (input: Float32Array) => {
// Dimension of each face (alias)
const mobileFaceNetFaceSize = 112;

View File

@@ -1,5 +1,5 @@
/**
* @file AI/ML related functionality.
* @file AI/ML related functionality, generic layer.
*
* @see also `ml-clip.ts`, `ml-face.ts`.
*
@@ -92,10 +92,10 @@ const modelPathDownloadingIfNeeded = async (
};
/** Return the path where the given {@link modelName} is meant to be saved */
export const modelSavePath = (modelName: string) =>
const modelSavePath = (modelName: string) =>
path.join(app.getPath("userData"), "models", modelName);
export const downloadModel = async (saveLocation: string, name: string) => {
const downloadModel = async (saveLocation: string, name: string) => {
// `mkdir -p` the directory where we want to save the model.
const saveDir = path.dirname(saveLocation);
await fs.mkdir(saveDir, { recursive: true });
@@ -112,7 +112,7 @@ export const downloadModel = async (saveLocation: string, name: string) => {
/**
* Create an ONNX {@link InferenceSession} with some defaults.
*/
export const createInferenceSession = async (modelPath: string) => {
const createInferenceSession = async (modelPath: string) => {
return await ort.InferenceSession.create(modelPath, {
// Restrict the number of threads to 1
intraOpNumThreads: 1,
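makeCachedInferenceSession itself is not shown in this diff; with modelSavePath, downloadModel, and createInferenceSession now module-private, it is the sole entry point they serve. A plausible sketch (an assumption, not the actual implementation) that memoizes a download-then-create promise:

import { existsSync } from "node:fs";
import fs from "node:fs/promises";
import * as ort from "onnxruntime-node";

const makeCachedInferenceSession = (modelName: string, byteSize: number) => {
    let session: Promise<ort.InferenceSession> | undefined;
    const createIfNeeded = async () => {
        const modelPath = modelSavePath(modelName);
        // (Re)download if the model is missing or its size doesn't match.
        if (
            !existsSync(modelPath) ||
            (await fs.stat(modelPath)).size !== byteSize
        )
            await downloadModel(modelPath, modelName);
        return createInferenceSession(modelPath);
    };
    // All callers share one in-flight promise; the Promise.race in ml-clip.ts
    // relies on that shared promise eventually becoming fulfilled.
    return () => (session ??= createIfNeeded());
};

A production version would presumably also clear the memoized promise on failure so that an interrupted download can be retried.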

View File

@@ -163,8 +163,10 @@ const runFFmpegCmd = (
const clipImageEmbedding = (jpegImageData: Uint8Array): Promise<Float32Array> =>
ipcRenderer.invoke("clipImageEmbedding", jpegImageData);
const clipTextEmbedding = (text: string): Promise<Float32Array> =>
ipcRenderer.invoke("clipTextEmbedding", text);
const clipTextEmbeddingIfAvailable = (
text: string,
): Promise<Float32Array | undefined> =>
ipcRenderer.invoke("clipTextEmbeddingIfAvailable", text);
const detectFaces = (input: Float32Array): Promise<Float32Array> =>
ipcRenderer.invoke("detectFaces", input);
@@ -263,42 +265,61 @@ const getElectronFilesFromGoogleZip = (
const getDirFiles = (dirPath: string): Promise<ElectronFile[]> =>
ipcRenderer.invoke("getDirFiles", dirPath);
//
// These objects exposed here will become available to the JS code in our
// renderer (the web/ code) as `window.ElectronAPIs.*`
//
// There are a few related concepts at play here, and it might be worthwhile to
// read their (excellent) documentation to get an understanding:
//
// - ContextIsolation:
// https://www.electronjs.org/docs/latest/tutorial/context-isolation
//
// - IPC https://www.electronjs.org/docs/latest/tutorial/ipc
//
// [Note: Transferring large amounts of data over IPC]
//
// Electron's IPC implementation uses the HTML standard Structured Clone
// Algorithm to serialize objects passed between processes.
// https://www.electronjs.org/docs/latest/tutorial/ipc#object-serialization
//
// In particular, ArrayBuffer is eligible for structured cloning.
// https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Structured_clone_algorithm
//
// Also, ArrayBuffer is "transferable", which means it is a zero-copy
// operation when it happens across threads.
// https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Transferable_objects
//
// In our case though, we're not dealing with threads but separate processes. So
// the ArrayBuffer will be copied:
// > "parameters, errors and return values are **copied** when they're sent over
// the bridge".
// https://www.electronjs.org/docs/latest/api/context-bridge#methods
//
// The copy itself is relatively fast, but the problem with transferring large
// amounts of data is potentially running out of memory during the copy.
//
// For an alternative, see [Note: IPC streams].
//
/**
* These objects exposed here will become available to the JS code in our
* renderer (the web/ code) as `window.ElectronAPIs.*`
*
* There are a few related concepts at play here, and it might be worthwhile to
* read their (excellent) documentation to get an understanding:
*
* - ContextIsolation:
* https://www.electronjs.org/docs/latest/tutorial/context-isolation
*
* - IPC https://www.electronjs.org/docs/latest/tutorial/ipc
*
* ---
*
* [Note: Custom errors across Electron/Renderer boundary]
*
* If we need to identify errors thrown by the main process when invoked from
* the renderer process, we can only use the `message` field because:
*
* > Errors thrown through `handle` in the main process are not transparent as
* > they are serialized and only the `message` property from the original error
* > is provided to the renderer process.
* >
* > - https://www.electronjs.org/docs/latest/tutorial/ipc
* >
* > Ref: https://github.com/electron/electron/issues/24427
*
* ---
*
* [Note: Transferring large amounts of data over IPC]
*
* Electron's IPC implementation uses the HTML standard Structured Clone
* Algorithm to serialize objects passed between processes.
* https://www.electronjs.org/docs/latest/tutorial/ipc#object-serialization
*
* In particular, ArrayBuffer is eligible for structured cloning.
* https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Structured_clone_algorithm
*
* Also, ArrayBuffer is "transferable", which means it is a zero-copy
* operation when it happens across threads.
* https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Transferable_objects
*
* In our case though, we're not dealing with threads but separate processes. So
* the ArrayBuffer will be copied:
*
* > "parameters, errors and return values are **copied** when they're sent over
* > the bridge".
* >
* > https://www.electronjs.org/docs/latest/api/context-bridge#methods
*
* The copy itself is relatively fast, but the problem with transferring large
* amounts of data is potentially running out of memory during the copy.
*
* For an alternative, see [Note: IPC streams].
*/
contextBridge.exposeInMainWorld("electron", {
// - General
@@ -340,7 +361,7 @@ contextBridge.exposeInMainWorld("electron", {
// - ML
clipImageEmbedding,
clipTextEmbedding,
clipTextEmbeddingIfAvailable,
detectFaces,
faceEmbedding,
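The copy-versus-transfer distinction in the note above can be demonstrated with the standard structuredClone (illustration only, not part of this commit):

// Structured cloning copies the backing ArrayBuffer; this is what happens to
// every typed array crossing the context bridge above.
const original = new Float32Array([0.1, 0.2, 0.3]);
const clone = structuredClone(original);
console.assert(clone.buffer !== original.buffer); // distinct memory: a copy

// Zero-copy requires an explicit *transfer*, which detaches the buffer from
// the sender. That is available between workers, but not across Electron IPC.
structuredClone(original, { transfer: [original.buffer] });
console.assert(original.byteLength === 0); // the sender's view is now detached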

View File

@@ -33,25 +33,10 @@ export interface PendingUploads {
/**
* Errors that have special semantics on the web side.
*
* [Note: Custom errors across Electron/Renderer boundary]
*
* We need to use the `message` field to disambiguate between errors thrown by
* the main process when invoked from the renderer process. This is because:
*
* > Errors thrown through `handle` in the main process are not transparent as
* > they are serialized and only the `message` property from the original error
* > is provided to the renderer process.
* >
* > - https://www.electronjs.org/docs/latest/tutorial/ipc
* >
* > Ref: https://github.com/electron/electron/issues/24427
*/
export const CustomErrors = {
WINDOWS_NATIVE_IMAGE_PROCESSING_NOT_SUPPORTED:
"Windows native image processing is not supported",
MODEL_DOWNLOAD_PENDING:
"Model download pending, skipping clip search request",
};
/**
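Since only the message field crosses the invoke boundary, and it arrives wrapped ("Error invoking remote method 'convertToJPEG': Error: <original message>"), the web side has to disambiguate by substring match. A hedged sketch of the call-site pattern these constants support (the exact convertToJPEG signature is assumed):

// Renderer-side sketch: match the surviving message by substring.
const convertToJPEGWithFallback = async (
    electron: Electron, // the exposed window.electron API
    fileName: string,
    imageData: Uint8Array,
) => {
    try {
        return await electron.convertToJPEG(fileName, imageData);
    } catch (e) {
        if (
            e instanceof Error &&
            e.message.includes(
                CustomErrors.WINDOWS_NATIVE_IMAGE_PROCESSING_NOT_SUPPORTED,
            )
        ) {
            return undefined; // caller falls back to a web based converter
        }
        throw e;
    }
};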

View File

@@ -184,8 +184,8 @@ class CLIPService {
}
};
getTextEmbedding = async (text: string) => {
return ensureElectron().clipTextEmbedding(text);
getTextEmbeddingIfAvailable = async (text: string) => {
return ensureElectron().clipTextEmbeddingIfAvailable(text);
};
private runClipEmbeddingExtraction = async (canceller: AbortController) => {

View File

@@ -1,5 +1,4 @@
import log from "@/next/log";
import { CustomError } from "@ente/shared/error";
import * as chrono from "chrono-node";
import { FILE_TYPE } from "constants/file";
import { t } from "i18next";
@@ -287,24 +286,20 @@ async function getLocationSuggestions(searchPhrase: string) {
return [...locationTagSuggestions, ...citySearchSuggestions];
}
async function getClipSuggestion(searchPhrase: string): Promise<Suggestion> {
try {
if (!clipService.isPlatformSupported()) {
return null;
}
const clipResults = await searchClip(searchPhrase);
return {
type: SuggestionType.CLIP,
value: clipResults,
label: searchPhrase,
};
} catch (e) {
if (!e.message?.includes(CustomError.MODEL_DOWNLOAD_PENDING)) {
log.error("getClipSuggestion failed", e);
}
async function getClipSuggestion(
searchPhrase: string,
): Promise<Suggestion | undefined> {
if (!clipService.isPlatformSupported()) {
return undefined;
}
const clipResults = await searchClip(searchPhrase);
if (!clipResults) return clipResults;
return {
type: SuggestionType.CLIP,
value: clipResults,
label: searchPhrase,
};
}
function searchCollection(
@@ -374,9 +369,14 @@ async function searchLocationTag(searchPhrase: string): Promise<LocationTag[]> {
return matchedLocationTags;
}
async function searchClip(searchPhrase: string): Promise<ClipSearchScores> {
const searchClip = async (
searchPhrase: string,
): Promise<ClipSearchScores | undefined> => {
const textEmbedding =
await clipService.getTextEmbeddingIfAvailable(searchPhrase);
if (!textEmbedding) return undefined;
const imageEmbeddings = await getLocalEmbeddings();
const textEmbedding = await clipService.getTextEmbedding(searchPhrase);
const clipSearchResult = new Map<number, number>(
(
await Promise.all(
@@ -394,7 +394,7 @@ async function searchClip(searchPhrase: string): Promise<ClipSearchScores> {
);
return clipSearchResult;
}
};
function convertSuggestionToSearchQuery(option: Suggestion): Search {
switch (option.type) {
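With the sentinel error gone, "model not ready" now flows in-band through the return types: clipTextEmbeddingIfAvailable, searchClip, and getClipSuggestion each surface undefined. A sketch of how an aggregation site (assumed, not shown in this diff) can then let absent CLIP results simply drop out:

// Hypothetical aggregation: no try/catch and no message matching; a pending
// model download just contributes no CLIP suggestion.
const getAllSuggestions = async (
    searchPhrase: string,
): Promise<Suggestion[]> => {
    const clipSuggestion = await getClipSuggestion(searchPhrase);
    const locationSuggestions = await getLocationSuggestions(searchPhrase);
    return clipSuggestion
        ? [clipSuggestion, ...locationSuggestions]
        : locationSuggestions;
};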

View File

@@ -240,7 +240,18 @@ export interface Electron {
clipImageEmbedding: (jpegImageData: Uint8Array) => Promise<Float32Array>;
/**
* Return a CLIP embedding of the given text.
* Return a CLIP embedding of the given text if we already have the model
* downloaded and prepped. If the model is not available, return `undefined`.
*
* This differs from the other sibling ML functions in that it doesn't wait
* for the model download to finish. It does trigger a model download, but
* then immediately returns `undefined`. At some future point, when the
* model download finishes, calls to this function will start returning
* the result we seek.
*
* The reason for this asymmetry is that CLIP text embeddings are used to
* compute results for user-initiated searches,
* and we don't want to block that interaction on a large network request.
*
* See: [Note: CLIP based magic search]
*
@@ -248,7 +259,9 @@
*
* @returns A CLIP embedding.
*/
clipTextEmbedding: (text: string) => Promise<Float32Array>;
clipTextEmbeddingIfAvailable: (
text: string,
) => Promise<Float32Array | undefined>;
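A sketch of the renderer-side usage this contract enables (the consumer shown is hypothetical): the call returns immediately either way, so a search typed before the model is ready degrades gracefully, and the same search issued later transparently picks up the downloaded model:

// Hypothetical consumer of the interface above; never blocks on the download.
const textEmbeddingForSearch = async (
    electron: Electron,
    query: string,
): Promise<Float32Array | undefined> => {
    const embedding = await electron.clipTextEmbeddingIfAvailable(query);
    // undefined means the model is still downloading (this call will also
    // have triggered the download); show only non-CLIP results for now.
    return embedding;
};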
/**
* Detect faces in the given image using YOLO.

View File

@@ -84,8 +84,6 @@ export const CustomError = {
ServerError: "server error",
FILE_NOT_FOUND: "file not found",
UNSUPPORTED_PLATFORM: "Unsupported platform",
MODEL_DOWNLOAD_PENDING:
"Model download pending, skipping clip search request",
UPDATE_URL_FILE_ID_MISMATCH: "update url file id mismatch",
URL_ALREADY_SET: "url already set",
FILE_CONVERSION_FAILED: "file conversion failed",