## Description

- Quantized the CLIP text encoder
- Moved preprocessing and postprocessing of face detection inside the model
- Optimised the ONNX models more wherever possible
- Created a place in infra for ML version control of sorts

## Tests

Have tested the changes on mobile, but not on desktop. Please carefully review the changes on desktop, especially regarding the face detection post-processing, more specifically the image (re-)size correction.
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Prepping MobileCLIP model for use in Ente\n",
|
|
"\n",
|
|
"[Paper](https://arxiv.org/pdf/2311.17049.pdf) | [Github](https://github.com/apple/ml-mobileclip)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Setting up Pytorch weights and source code"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# !mkdir mobileclip_repo\n",
|
|
"# %cd mobileclip_repo\n",
|
|
"# !git clone https://github.com/apple/ml-mobileclip.git\n",
|
|
"# %cd ml-mobileclip"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"%cd mobileclip_repo/ml-mobileclip/"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# !source get_pretrained_models.sh # Files will be downloaded to `checkpoints` directory.\n",
|
|
"# %cd ../.."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Imports"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"!uv pip install clip-benchmark>=1.4.0 datasets>=2.8.0 open-clip-torch>=2.20.0 timm>=0.9.5"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch\n",
|
|
"import torch.onnx\n",
|
|
"import torchvision\n",
|
|
"import torch.nn as nn\n",
|
|
"from PIL import Image\n",
|
|
"import mobileclip\n",
|
|
"import numpy as np\n",
|
|
"from numpy.linalg import norm\n",
|
|
"import onnx\n",
|
|
"import onnxruntime as ort\n",
|
|
"print(ort.__version__)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"model, _, preprocess = mobileclip.create_model_and_transforms('mobileclip_s2', pretrained='checkpoints/mobileclip_s2.pt')\n",
|
|
"og_model = model\n",
|
|
"model.eval()\n",
|
|
"og_model.eval()\n",
|
|
"tokenizer = mobileclip.get_tokenizer('mobileclip_s2')\n",
|
|
"\n",
|
|
"image = preprocess(Image.open(\"docs/fig_accuracy_latency.png\").convert('RGB')).unsqueeze(0)\n",
|
|
"text = tokenizer([\"Hello World!\", \"a diagram\", \"a dog\", \"a cat\"])\n",
|
|
"\n",
|
|
"with torch.no_grad(), torch.cuda.amp.autocast():\n",
|
|
" image_features = model.encode_image(image)\n",
|
|
" text_features = model.encode_text(text)\n",
|
|
" image_features /= image_features.norm(dim=-1, keepdim=True)\n",
|
|
" text_features /= text_features.norm(dim=-1, keepdim=True)\n",
|
|
"\n",
|
|
" text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)\n",
|
|
"\n",
|
|
"print(\"Label probs:\", text_probs)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"%cd ../.."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# !rm -rf mobileclip_repo"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"tokenizer([\"This is a tokenized string\"])"
|
|
]
|
|
},
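{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check on the tokenizer output (a minimal sketch, reusing the `tokenizer` created above): the CLIP tokenizer pads/truncates to a 77-token context window and returns int64 token ids, which is why the text encoder is exported with an int64 input below and only later switched to int32."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tokens = tokenizer([\"This is a tokenized string\"])\n",
"print(tokens.shape)  # expected: torch.Size([1, 77])\n",
"print(tokens.dtype)  # expected: torch.int64"
]
},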
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"text_input = tokenizer([\"Hello World! This is a super duper long piece of text of at least 77 tokens, purely to make sure that indeed this is a good input without any zeros that the exporter might somehow confuse with a boolean. Apparently we're still not at 77 tokens, so I just keep on monkey typing this story in the hope that someday I have a fully tokenized string of text that is longer than the required 77 tokens. Thank you for coming to my TED talk.\"])\n",
|
|
"text_emb = model.encode_text(text_input)[0].detach().numpy()\n",
|
|
"text_emb /= norm(text_emb)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"preprocess"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from PIL import Image"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"image_singapore = Image.open(\"../data/singapore.jpg\").convert('RGBA')\n",
|
|
"image_input = preprocess(image_singapore).unsqueeze(0)\n",
|
|
"print(image_input.detach().numpy().shape)\n",
|
|
"print(1*3*256*256)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"image_emb = model(image_input[:,:3,:,:])[0][0].detach().numpy()\n",
|
|
"print(image_emb.shape)\n",
|
|
"print(norm(image_emb))\n",
|
|
"image_emb[0:5]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"image_singapore_onnx = np.array(image_singapore)\n",
|
|
"print(image_singapore_onnx.shape)\n",
|
|
"print(image_singapore_onnx.dtype)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Export to ONNX"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"onnx_opset = 18 # use opset 18 for Resize to antialias"
|
|
]
|
|
},
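{
"cell_type": "markdown",
"metadata": {},
"source": [
"Small sanity check (a sketch, not required by the export itself): confirm that the installed `onnx` package knows about the chosen opset, since the `antialias` attribute of `Resize` used by the preprocessing below only exists from opset 18 onwards."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(onnx.defs.onnx_opset_version())  # highest opset supported by the installed onnx package\n",
"assert onnx.defs.onnx_opset_version() >= onnx_opset"
]
},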
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Image model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class EncodeImageWrapper(nn.Module):\n",
|
|
" def __init__(self, original_model):\n",
|
|
" super(EncodeImageWrapper, self).__init__()\n",
|
|
" self.original_model = original_model\n",
|
|
"\n",
|
|
" def forward(self, input):\n",
|
|
" return self.original_model.encode_image(input)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"image_model_wrapper = EncodeImageWrapper(model)\n",
|
|
"image_model_wrapper.eval()\n",
|
|
"image_model_wrapper.original_model.eval()\n",
|
|
"clip_image_onnx_export_path = \"onnx_models/mobileclip_s2_image_float32.onnx\"\n",
|
|
"torch.onnx.export(image_model_wrapper, image, clip_image_onnx_export_path, opset_version=onnx_opset, do_constant_folding=True, input_names=[\"input\"], output_names=[\"output\"])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"mobileclip_image_onnx = onnx.load(clip_image_onnx_export_path)\n",
|
|
"onnx.checker.check_model(mobileclip_image_onnx)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Text model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class EncodeTextWrapper(nn.Module):\n",
|
|
" def __init__(self, original_model):\n",
|
|
" super(EncodeTextWrapper, self).__init__()\n",
|
|
" self.original_model = original_model\n",
|
|
"\n",
|
|
" def forward(self, input):\n",
|
|
" return self.original_model.encode_text(input)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"text_model_wrapper = EncodeTextWrapper(model)\n",
|
|
"text_model_wrapper.eval()\n",
|
|
"text_model_wrapper.original_model.eval()\n",
|
|
"clip_text_onnx_export_path = \"onnx_models/mobileclip_s2_text_int64.onnx\"\n",
|
|
"torch.onnx.export(text_model_wrapper, text_input, clip_text_onnx_export_path, opset_version=onnx_opset, do_constant_folding=True, input_names=['input'], output_names=['output'])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Altering ONNX models"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Image model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Change input name to `og_input` so we can reserve `input` for altered model that includes preprocessing"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 21,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"og_input = onnx.helper.make_tensor_value_info(\n",
|
|
" name=\"og_input\",\n",
|
|
" elem_type=onnx.TensorProto.FLOAT,\n",
|
|
" shape=[1, 3, 256, 256], \n",
|
|
")\n",
|
|
"\n",
|
|
"# Update the input names in the rest of the model\n",
|
|
"for node in mobileclip_image_onnx.graph.node:\n",
|
|
" for i, input_name in enumerate(node.input):\n",
|
|
" if input_name == \"input\":\n",
|
|
" node.input[i] = \"og_input\"\n",
|
|
"\n",
|
|
"graph = onnx.helper.make_graph(\n",
|
|
" nodes=mobileclip_image_onnx.graph.node,\n",
|
|
" name=mobileclip_image_onnx.graph.name,\n",
|
|
" inputs=[og_input],\n",
|
|
" outputs=mobileclip_image_onnx.graph.output,\n",
|
|
" initializer=mobileclip_image_onnx.graph.initializer,\n",
|
|
")\n",
|
|
"mobileclip_image_onnx = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid(\"\", onnx_opset)])\n",
|
|
"onnx.save_model(mobileclip_image_onnx, clip_image_onnx_export_path)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Add preprocessing to the model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from onnxruntime_extensions.tools.pre_post_processing import PrePostProcessor, create_named_value, Resize, ImageBytesToFloat, Unsqueeze, CenterCrop, Debug, ChannelsLastToChannelsFirst"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"inputs = [create_named_value(\"input_to_process\", onnx.TensorProto.UINT8, [\"H\", \"W\", \"C\"])]\n",
|
|
"\n",
|
|
"pipeline = PrePostProcessor(inputs, onnx_opset)\n",
|
|
"\n",
|
|
"pipeline.add_pre_processing(\n",
|
|
" [\n",
|
|
" Resize(256), \n",
|
|
" CenterCrop(256, 256), # Crop to 256x256. NOTE: Currently only HWC input is handled.\n",
|
|
" ChannelsLastToChannelsFirst(), # Convert to CHW\n",
|
|
" # Debug(),\n",
|
|
" ImageBytesToFloat(), # Convert to float in range 0..1 by dividing uint8 values by 255\n",
|
|
" # Debug(),\n",
|
|
" Unsqueeze([0]), # add batch, CHW --> 1CHW\n",
|
|
" # Debug(),\n",
|
|
" ]\n",
|
|
")\n",
|
|
"\n",
|
|
"clip_image_with_preprocessing = pipeline.run(mobileclip_image_onnx)\n",
|
|
"\n",
|
|
"onnx.checker.check_model(clip_image_with_preprocessing)\n",
|
|
"clip_image_onnx_rgb_path = f\"onnx_models/mobileclip_s2_image_opset{onnx_opset}_rgb.onnx\"\n",
|
|
"new_model_path = clip_image_onnx_rgb_path\n",
|
|
"onnx.save_model(clip_image_with_preprocessing, new_model_path)"
|
|
]
|
|
},
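{
"cell_type": "markdown",
"metadata": {},
"source": [
"Minimal sanity check (a sketch) that the baked-in preprocessing matches the torchvision `preprocess` pipeline: run the RGB model directly on raw HWC uint8 bytes and compare against the PyTorch embedding computed earlier. The input name is looked up from the session instead of being assumed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pre_check_sess = ort.InferenceSession(clip_image_onnx_rgb_path)\n",
"pre_input_name = pre_check_sess.get_inputs()[0].name\n",
"rgb_bytes = np.array(Image.open(\"../data/singapore.jpg\").convert('RGB'))\n",
"pre_check_emb = pre_check_sess.run(None, {pre_input_name: rgb_bytes})[0][0]\n",
"pre_check_emb /= norm(pre_check_emb)\n",
"np.dot(image_emb, pre_check_emb)  # should be close to 1.0"
]
},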
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Add a slice node so that the model can take raw RGBA data as input (as well as standard RGB)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 24,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"onnx_model = clip_image_with_preprocessing\n",
|
|
"\n",
|
|
"# Create a new input with flexible channel dimension\n",
|
|
"new_input = onnx.helper.make_tensor_value_info(\n",
|
|
" name=\"input\",\n",
|
|
" elem_type=onnx.TensorProto.UINT8,\n",
|
|
" shape=[\"H\", \"W\", \"C\"], \n",
|
|
")\n",
|
|
"\n",
|
|
"# Create constant tensors for starts, ends, and axes\n",
|
|
"starts_tensor = onnx.helper.make_tensor(\n",
|
|
" name=\"starts\",\n",
|
|
" data_type=onnx.TensorProto.INT64,\n",
|
|
" dims=[1],\n",
|
|
" vals=np.array([0], dtype=np.int64)\n",
|
|
")\n",
|
|
"ends_tensor = onnx.helper.make_tensor(\n",
|
|
" name=\"ends\",\n",
|
|
" data_type=onnx.TensorProto.INT64,\n",
|
|
" dims=[1],\n",
|
|
" vals=np.array([3], dtype=np.int64)\n",
|
|
")\n",
|
|
"axes_tensor = onnx.helper.make_tensor(\n",
|
|
" name=\"axes\",\n",
|
|
" data_type=onnx.TensorProto.INT64,\n",
|
|
" dims=[1],\n",
|
|
" vals=np.array([2], dtype=np.int64)\n",
|
|
")\n",
|
|
"new_initializers = [starts_tensor, ends_tensor, axes_tensor] + list(onnx_model.graph.initializer)\n",
|
|
"slice_node = onnx.helper.make_node(\n",
|
|
" \"Slice\",\n",
|
|
" inputs=[\"input\", \"starts\", \"ends\", \"axes\"],\n",
|
|
" outputs=[\"sliced_input\"],\n",
|
|
" name=\"slice_rgba_input_node\"\n",
|
|
")\n",
|
|
"\n",
|
|
"\n",
|
|
"# Add the new input and Slice node to the graph\n",
|
|
"graph = onnx.helper.make_graph(\n",
|
|
" [slice_node] + list(onnx_model.graph.node), # Prepend Slice node to existing nodes\n",
|
|
" onnx_model.graph.name,\n",
|
|
" [new_input],\n",
|
|
" list(onnx_model.graph.output),\n",
|
|
" initializer=new_initializers,\n",
|
|
" value_info=onnx_model.graph.value_info,\n",
|
|
")\n",
|
|
"\n",
|
|
"# Create the new model\n",
|
|
"mobileclip_image_onnx_rgba = onnx.helper.make_model(\n",
|
|
" graph,\n",
|
|
" opset_imports=[onnx.helper.make_opsetid(\"\", onnx_opset)]\n",
|
|
")\n",
|
|
"\n",
|
|
"\n",
|
|
"# Update the input names in the rest of the model\n",
|
|
"for node in mobileclip_image_onnx_rgba.graph.node:\n",
|
|
" for i, input_name in enumerate(node.input):\n",
|
|
" if input_name == \"input_to_process\":\n",
|
|
" node.input[i] = \"sliced_input\"\n",
|
|
"\n",
|
|
"# Save the new model\n",
|
|
"onnx.checker.check_model(mobileclip_image_onnx_rgba)\n",
|
|
"clip_image_onnx_rgba_path = f\"onnx_models/mobileclip_s2_image_opset{onnx_opset}_rgba.onnx\"\n",
|
|
"onnx.save(mobileclip_image_onnx_rgba, clip_image_onnx_rgba_path)"
|
|
]
|
|
},
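{
"cell_type": "markdown",
"metadata": {},
"source": [
"Quick check (a sketch) that the added Slice node makes RGB and RGBA inputs equivalent: run the same image through the model once as raw (H, W, 3) bytes and once as raw (H, W, 4) bytes and compare the embeddings."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rgba_check_sess = ort.InferenceSession(clip_image_onnx_rgba_path)\n",
"rgb_arr = np.array(Image.open(\"../data/singapore.jpg\").convert('RGB'))\n",
"rgba_arr = np.array(Image.open(\"../data/singapore.jpg\").convert('RGBA'))\n",
"emb_from_rgb = rgba_check_sess.run(None, {\"input\": rgb_arr})[0][0]\n",
"emb_from_rgba = rgba_check_sess.run(None, {\"input\": rgba_arr})[0][0]\n",
"emb_from_rgb /= norm(emb_from_rgb)\n",
"emb_from_rgba /= norm(emb_from_rgba)\n",
"np.dot(emb_from_rgb, emb_from_rgba)  # should be ~1.0 since the alpha channel is sliced off"
]
},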
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Optimize the model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 25,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"clip_image_sim_path = f\"onnx_models/mobileclip_s2_image_opset{onnx_opset}_rgba_sim.onnx\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"!onnxsim {clip_image_onnx_rgba_path} {clip_image_sim_path}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Optimize the graph"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 27,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"image_opt_sess_options = ort.SessionOptions()\n",
|
|
"\n",
|
|
"image_opt_sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL\n",
|
|
"image_opt_sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC\n",
|
|
"\n",
|
|
"clip_image_opt_path = f\"onnx_models/mobileclip_s2_image_opset{onnx_opset}_rgba_opt.onnx\"\n",
|
|
"image_opt_sess_options.optimized_model_filepath = clip_image_opt_path\n",
|
|
"\n",
|
|
"opt_image_session = ort.InferenceSession(clip_image_sim_path, image_opt_sess_options)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Add metadata to the model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 28,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"clip_image_opt = onnx.load(clip_image_opt_path)\n",
|
|
"clip_image_opt.producer_name = \"EnteMobileCLIPImageEncoder\"\n",
|
|
"clip_image_opt.doc_string = \"MobileCLIP S2 Image Encoder with built-in preprocessing. Accepts both RGB and RGBA raw bytes input (uint8) in HWC format.\"\n",
|
|
"clip_image_opt.graph.doc_string = \"\"\n",
|
|
"clip_image_opt.graph.name = \"SliceRGB+Resize+CenterCrop+ToFloat+Unsqueeze+MobileCLIP_S2_ImageEncoder\"\n",
|
|
"onnx.save(clip_image_opt, clip_image_opt_path)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Test the model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"ort_session = ort.InferenceSession(clip_image_opt_path)\n",
|
|
"onnx_emb = ort_session.run(None, {\"input\": image_singapore_onnx})[0][0]\n",
|
|
"onnx_emb /= norm(onnx_emb)\n",
|
|
"np.dot(image_emb, onnx_emb)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 30,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"!rm {clip_image_onnx_export_path}\n",
|
|
"!rm {clip_image_onnx_rgb_path}\n",
|
|
"!rm {clip_image_onnx_rgba_path}\n",
|
|
"!rm {clip_image_sim_path}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Text model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Make sure the model can use int32 as input"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 31,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"mobileclip_text_onxx = onnx.load(clip_text_onnx_export_path)\n",
|
|
"\n",
|
|
"for tensor in mobileclip_text_onxx.graph.input:\n",
|
|
" if tensor.name == \"input\":\n",
|
|
" tensor.type.tensor_type.elem_type = onnx.TensorProto.INT32\n",
|
|
" break\n",
|
|
"\n",
|
|
"# Save the modified model\n",
|
|
"clip_text_onnx_int32_path = \"onnx_models/mobileclip_s2_text_int32.onnx\"\n",
|
|
"onnx.save(mobileclip_text_onxx, clip_text_onnx_int32_path)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"[Simplify](https://github.com/daquexian/onnx-simplifier) the model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 32,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"clip_text_sim_path = f\"onnx_models/mobileclip_s2_text_opset{onnx_opset}_int32_sim.onnx\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"!onnxsim {clip_text_onnx_int32_path} {clip_text_sim_path}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Apply basic offline [graph optimizations](https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html). Only do the basic optimizations offline, the extended and layout optimizations should be done online depending on execution provider and hardware."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 34,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"text_opt_sess_options = ort.SessionOptions()\n",
|
|
"\n",
|
|
"text_opt_sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL\n",
|
|
"text_opt_sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC\n",
|
|
"\n",
|
|
"clip_text_opt_path = f\"onnx_models/mobileclip_s2_text_opset{onnx_opset}_int32_opt.onnx\"\n",
|
|
"text_opt_sess_options.optimized_model_filepath = clip_text_opt_path\n",
|
|
"\n",
|
|
"opt_text_session = ort.InferenceSession(clip_text_sim_path, text_opt_sess_options)"
|
|
]
|
|
},
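{
"cell_type": "markdown",
"metadata": {},
"source": [
"Illustrative sketch of the online half of the split described above: at load time a client would create the session from the offline-optimized file with the full optimization level enabled, letting onnxruntime apply the extended and layout optimizations for the local execution provider and hardware. The CPU provider here is only an example; on mobile or desktop a different provider would be passed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"online_sess_options = ort.SessionOptions()\n",
"online_sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL\n",
"online_session = ort.InferenceSession(clip_text_opt_path, online_sess_options, providers=[\"CPUExecutionProvider\"])\n",
"print(online_session.get_providers())"
]
},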
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Add metadata to the model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 35,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"clip_text_opt = onnx.load(clip_text_opt_path)\n",
|
|
"clip_text_opt.producer_name = \"EnteMobileCLIPTextEncoder\"\n",
|
|
"clip_text_opt.doc_string = \"MobileCLIP S2 Text Encoder. Accepts an integer array (int32) of length 77. Longer arrays will be truncated.\"\n",
|
|
"clip_text_opt.graph.doc_string = \"\"\n",
|
|
"clip_text_opt.graph.name = \"MobileCLIP_S2_TextEncoder\"\n",
|
|
"onnx.save(clip_text_opt, clip_text_opt_path)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Test the model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"mobileclip_text_ort_sess = ort.InferenceSession(clip_text_opt_path)\n",
|
|
"text_onnx_emb = mobileclip_text_ort_sess.run([\"output\"], {\"input\": text_input.numpy().astype(\"int32\")})[0][0]\n",
|
|
"text_onnx_emb /= norm(text_onnx_emb)\n",
|
|
"np.dot(text_emb, text_onnx_emb)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 37,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"!rm {clip_text_onnx_export_path}\n",
|
|
"!rm {clip_text_onnx_int32_path}\n",
|
|
"!rm {clip_text_sim_path}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Quantize text model\n",
|
|
"\n",
|
|
"https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Quantization pre-processing (not to confuse with normal pre-processing)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 38,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from onnxruntime.quantization import quant_pre_process"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 39,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"clip_text_quantized_preprocessed_path = \"onnx_models/mobileclip_s2_text_quant_preprocessed.onnx\"\n",
|
|
"quant_pre_process(clip_text_opt_path, clip_text_quantized_preprocessed_path)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Dynamic quantization"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 40,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from onnxruntime.quantization import quantize_dynamic, quantize_static, QuantType"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"node_names = []\n",
|
|
"matmul_nodes_names = []\n",
|
|
"for node in clip_text_opt.graph.node:\n",
|
|
" node_names.append(node.name)\n",
|
|
" if node.op_type == \"MatMul\" and node.name != \"/text_encoder/transformer.0/pre_norm_ffn/pre_norm_ffn.4/MatMul\":\n",
|
|
" matmul_nodes_names.append(node.name)\n",
|
|
"len(node_names)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 42,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"clip_text_quantized_dynamic_path = f\"onnx_models/mobileclip_s2_text_opset{onnx_opset}_quant.onnx\"\n",
|
|
"quantize_dynamic(clip_text_quantized_preprocessed_path, clip_text_quantized_dynamic_path, nodes_to_exclude=node_names[28])"
|
|
]
|
|
},
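{
"cell_type": "markdown",
"metadata": {},
"source": [
"Rough look (a sketch) at what dynamic quantization buys in terms of model size on disk, comparing the optimized float32 text encoder with the quantized one."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"size_fp32_mb = os.path.getsize(clip_text_opt_path) / 1e6\n",
"size_quant_mb = os.path.getsize(clip_text_quantized_dynamic_path) / 1e6\n",
"print(f\"float32 text encoder: {size_fp32_mb:.1f} MB\")\n",
"print(f\"quantized text encoder: {size_quant_mb:.1f} MB\")"
]
},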
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"mobileclip_text_quant_dyn_ort_sess = ort.InferenceSession(clip_text_quantized_dynamic_path)\n",
|
|
"text_onnx_quant_dyn_emb = mobileclip_text_quant_dyn_ort_sess.run([\"output\"], {\"input\": text_input.numpy().astype(\"int32\")})[0][0]\n",
|
|
"text_onnx_quant_dyn_emb /= norm(text_onnx_quant_dyn_emb)\n",
|
|
"np.dot(text_onnx_quant_dyn_emb, text_onnx_emb)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Quantization Debugging (uncomment if you want to try it)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 44,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# exclude_amount = 1\n",
|
|
"\n",
|
|
"\n",
|
|
"# for i in range(25, 30, exclude_amount):\n",
|
|
"# begin = i\n",
|
|
"# end = min(i+exclude_amount, len(node_names))\n",
|
|
" \n",
|
|
"# clip_text_quantized_dynamic_debug_path = f\"onnx_models/mobileclip_s2_text_opset{onnx_opset}_int8dyn_opt_debug.onnx\"\n",
|
|
"# quantize_dynamic(clip_text_quantized_preprocessed_path, clip_text_quantized_dynamic_debug_path, nodes_to_exclude=node_names[begin:end])\n",
|
|
"# mobileclip_text_quant_dyn_ort_sess_debug = ort.InferenceSession(clip_text_quantized_dynamic_debug_path)\n",
|
|
"# text_onnx_quant_dyn_emb_debug = mobileclip_text_quant_dyn_ort_sess_debug.run([\"output\"], {\"input\": text_input.numpy().astype(\"int32\")})[0][0]\n",
|
|
"# text_onnx_quant_dyn_emb_debug /= norm(text_onnx_quant_dyn_emb_debug)\n",
|
|
"# sim_debug = np.dot(text_onnx_quant_dyn_emb_debug, text_onnx_emb)\n",
|
|
"# print(f\"Skipping nodes from {begin} to {end} resulted in a similarity of {sim_debug:.4f}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"node_names[28:29]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Test on a dataset of image captions. Before continuing, download the dataset from [Kaggle](https://www.kaggle.com/datasets/aladdinpersson/flickr8kimagescaptions/data) and put it in the `../data` folder"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 46,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import csv\n",
|
|
"from tqdm import tqdm\n",
|
|
"import time\n",
|
|
"import copy\n",
|
|
"import matplotlib.pyplot as plt"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"captions = []\n",
|
|
"\n",
|
|
"with open('../data/flickr8k_captions.txt', 'r', encoding='utf-8') as file:\n",
|
|
" csv_reader = csv.reader(file)\n",
|
|
" next(csv_reader)\n",
|
|
" for row in csv_reader:\n",
|
|
" captions.append(row[1])\n",
|
|
"\n",
|
|
"print(len(captions))\n",
|
|
"print(captions[:5])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Test accuracy of quantized model quickly (uncomment code below)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"test_size = 600\n",
|
|
"similarities = np.zeros(test_size)\n",
|
|
"mobileclip_text_quant_dyn_ort_sess = ort.InferenceSession(clip_text_quantized_dynamic_path)\n",
|
|
"\n",
|
|
"for i, caption in tqdm(enumerate(captions[:test_size])):\n",
|
|
" text_input_test = tokenizer([caption])\n",
|
|
" text_emb_test = model.encode_text(text_input_test)[0].detach().numpy()\n",
|
|
" text_emb_test /= norm(text_emb_test)\n",
|
|
" text_onnx_test_emb = mobileclip_text_quant_dyn_ort_sess.run([\"output\"], {\"input\": text_input_test.numpy().astype(\"int32\")})[0][0]\n",
|
|
" text_onnx_test_emb /= norm(text_onnx_test_emb)\n",
|
|
" similarities[i] = np.dot(text_onnx_test_emb, text_emb_test)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"print(f\"Mean similarity: {similarities.mean()}\")\n",
|
|
"print(f\"Standard deviation: {similarities.std()}\")\n",
|
|
"print(f\"Minimum similarity: {similarities.min()}\")\n",
|
|
"print(f\"Maximum similarity: {similarities.max()}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Test accuracy of quantized model extensively (uncomment code below)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 50,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# captions_extensive = copy.deepcopy(captions)\n",
|
|
"\n",
|
|
"# for i in range(10000):\n",
|
|
"# captions_extensive[i] = captions_extensive[i] + \" \" + captions_extensive[i + 10000] + \" \" + captions_extensive[i + 20000] + \" \" + captions_extensive[i + 30000]\n",
|
|
"# captions_extensive[i + 10000] = captions_extensive[i + 10000] + \" \" + captions_extensive[i + 20000] + \" \" + captions_extensive[i + 30000]\n",
|
|
"# captions_extensive[i + 20000] = captions_extensive[i + 20000] + \" \" + captions_extensive[i + 30000]\n",
|
|
"# captions_extensive = captions_extensive[:40000]\n",
|
|
"\n",
|
|
"# test_size = len(captions_extensive)\n",
|
|
"# similarities_extensive = np.zeros(test_size)\n",
|
|
"# mobileclip_text_quant_dyn_ort_sess = ort.InferenceSession(clip_text_quantized_dynamic_path)\n",
|
|
"\n",
|
|
"# for i, caption in tqdm(enumerate(captions_extensive[:test_size])):\n",
|
|
"# text_input_test = tokenizer([caption])\n",
|
|
"# text_emb_test = model.encode_text(text_input_test)[0].detach().numpy()\n",
|
|
"# text_emb_test /= norm(text_emb_test)\n",
|
|
"# text_onnx_test_emb = mobileclip_text_quant_dyn_ort_sess.run([\"output\"], {\"input\": text_input_test.numpy().astype(\"int32\")})[0][0]\n",
|
|
"# text_onnx_test_emb /= norm(text_onnx_test_emb)\n",
|
|
"# similarities_extensive[i] = np.dot(text_onnx_test_emb, text_emb_test)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 51,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# print(f\"Mean similarity: {similarities_extensive.mean()}\")\n",
|
|
"# print(f\"Standard deviation: {similarities_extensive.std()}\")\n",
|
|
"# print(f\"Minimum similarity: {similarities_extensive.min()}\")\n",
|
|
"# print(f\"Maximum similarity: {similarities_extensive.max()}\")\n",
|
|
"# print(f\"Percentage of similarities above 0.99: {np.sum(similarities_extensive > 0.99) / len(similarities_extensive) * 100:.2f}%\")\n",
|
|
"# print(f\"Percentage of similarities above 0.995: {np.sum(similarities_extensive > 0.995) / len(similarities_extensive) * 100:.2f}%\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Investigating the MatMul excluded from quantization to improve performance (uncomment code below)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 52,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# quant_model = onnx.load(clip_text_opt_path)\n",
|
|
"# node_name = node_names[28] # /text_encoder/transformer.0/pre_norm_ffn/pre_norm_ffn.4/MatMul\n",
|
|
"# # use_node_name = matmul_nodes_names[8]\n",
|
|
"# use_node_name = node_name\n",
|
|
"\n",
|
|
"# # Find the MatMul node\n",
|
|
"# special_matmul_node = None\n",
|
|
"# for node in quant_model.graph.node:\n",
|
|
"# if node.op_type == 'MatMul' and node.name == use_node_name:\n",
|
|
"# special_matmul_node = node\n",
|
|
"# print(f\"MatMul node found: {special_matmul_node.name}\")\n",
|
|
"# break\n",
|
|
"\n",
|
|
"# if special_matmul_node is None:\n",
|
|
"# raise ValueError(f\"MatMul node with name '{use_node_name}' not found in the model.\")\n",
|
|
"\n",
|
|
"# # Get the weight tensor\n",
|
|
"# weight_name = special_matmul_node.input[1]\n",
|
|
"# special_weight_tensor = None\n",
|
|
"# for init in quant_model.graph.initializer:\n",
|
|
"# if init.name == weight_name:\n",
|
|
"# special_weight_tensor = init\n",
|
|
"# break\n",
|
|
"\n",
|
|
"# if special_weight_tensor is None:\n",
|
|
"# raise ValueError(f\"Weight tensor for MatMul node '{use_node_name}' not found.\")\n",
|
|
"\n",
|
|
"# special_weight_array = onnx.numpy_helper.to_array(special_weight_tensor)\n",
|
|
"\n",
|
|
"# mean = np.mean(special_weight_array)\n",
|
|
"# std = np.std(special_weight_array)\n",
|
|
"# min_val = np.min(special_weight_array)\n",
|
|
"# max_val = np.max(special_weight_array)\n",
|
|
"\n",
|
|
"# print(f\"Statistical Analysis for MatMul node '{use_node_name}':\")\n",
|
|
"# print(f\"Mean: {mean}\")\n",
|
|
"# print(f\"Standard Deviation: {std}\")\n",
|
|
"# print(f\"Minimum: {min_val}\")\n",
|
|
"# print(f\"Maximum: {max_val}\")\n",
|
|
"# print(f\"Dynamic Range: {max_val - min_val}\")\n",
|
|
"\n",
|
|
"# plt.figure(figsize=(10, 6))\n",
|
|
"# plt.hist(special_weight_array.flatten(), bins=50, edgecolor='black')\n",
|
|
"# plt.title(f\"Histogram of Weights for MatMul node '{use_node_name}'\")\n",
|
|
"# plt.xlabel(\"Weight Value\")\n",
|
|
"# plt.ylabel(\"Frequency\")\n",
|
|
"# plt.show()\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Test speed of quantized model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 53,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# time_test_size = 1000\n",
|
|
"# mobileclip_text_quant_dyn_ort_sess = ort.InferenceSession(clip_text_quantized_dynamic_path)\n",
|
|
"# times_unquantized = np.zeros(time_test_size)\n",
|
|
"# times_quantized = np.zeros(time_test_size)\n",
|
|
"\n",
|
|
"# # Time of unquantized model\n",
|
|
"# print(\"Timing unquantized model...\")\n",
|
|
"# for i, caption in tqdm(enumerate(captions[:time_test_size])):\n",
|
|
"# text_input_test = tokenizer([caption])\n",
|
|
"# start = time.time()\n",
|
|
"# _ = model.encode_text(text_input_test)\n",
|
|
"# end = time.time()\n",
|
|
"# times_unquantized[i] = end - start\n",
|
|
"\n",
|
|
"# # Time of quantized model\n",
|
|
"# print(\"Timing quantized model...\")\n",
|
|
"# for i, caption in tqdm(enumerate(captions[:time_test_size])):\n",
|
|
"# text_input_test = tokenizer([caption]).numpy().astype(\"int32\")\n",
|
|
"# start = time.time()\n",
|
|
"# _ = mobileclip_text_quant_dyn_ort_sess.run([\"output\"], {\"input\": text_input_test})\n",
|
|
"# end = time.time()\n",
|
|
"# times_quantized[i] = end - start\n",
|
|
"\n",
|
|
"# original_mean = times_unquantized.mean()\n",
|
|
"# original_std = times_unquantized.std()\n",
|
|
"# quantized_mean = times_quantized.mean()\n",
|
|
"# quantized_std = times_quantized.std()\n",
|
|
"\n",
|
|
"# print(f\"Original model: {original_mean:.6f} ± {original_std:.6f} seconds\")\n",
|
|
"# print(f\"Quantized model: {quantized_mean:.6f} ± {quantized_std:.6f} seconds\")\n",
|
|
"# print(f\"Speedup: {original_mean / quantized_mean:.2f}x\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 54,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"!rm {clip_text_quantized_preprocessed_path}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Quantizing image model\n",
|
|
"\n",
|
|
"Eventually got it to roughly 0.996 similarity with the original model, at a reduction of 54MB, from 143 to 89MB. Also not bad, but since it's less of a reduction and the resulting embeddings will be stored permanently we decided not to use it. Uncomment code below to restart investigation if wanted."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 55,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# image_node_names = []\n",
|
|
"# image_matmul_nodes_names = []\n",
|
|
"# image_conv_nodes_names = []\n",
|
|
"# for node in clip_image_opt.graph.node:\n",
|
|
"# image_node_names.append(node.name)\n",
|
|
"# if node.op_type == \"MatMul\":\n",
|
|
"# image_matmul_nodes_names.append(node.name)\n",
|
|
"# if node.op_type == \"Conv\":\n",
|
|
"# image_conv_nodes_names.append(node.name)\n",
|
|
"# print(len(image_node_names))\n",
|
|
"# print(len(image_matmul_nodes_names))\n",
|
|
"# print(len(image_conv_nodes_names))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 56,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# clip_image_quantized_dynamic_path = f\"onnx_models/mobileclip_s2_image_opset{onnx_opset}_int8_opt.onnx\"\n",
|
|
"# exclude = list(set(image_node_names[:100] + image_conv_nodes_names))\n",
|
|
"# quantize_dynamic(clip_image_opt_path, clip_image_quantized_dynamic_path, weight_type=QuantType.QUInt8, nodes_to_exclude=exclude)\n",
|
|
"\n",
|
|
"# mobileclip_image_quant_dyn_ort_sess = ort.InferenceSession(clip_image_quantized_dynamic_path)\n",
|
|
"# image_onnx_quant_dyn_emb = mobileclip_image_quant_dyn_ort_sess.run([\"output\"], {\"input\": image_singapore_onnx})[0][0]\n",
|
|
"# image_onnx_quant_dyn_emb /= norm(image_onnx_quant_dyn_emb)\n",
|
|
"# np.dot(image_onnx_quant_dyn_emb, image_emb)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Debug quantizations"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 57,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# exclude_amount = 50\n",
|
|
"# exclude_for_sure = image_node_names[:100] + image_node_names[225:260] + image_node_names[280:300] + image_node_names[430:480] + image_node_names[510:560] + image_node_names[650:]\n",
|
|
"\n",
|
|
"# image_test_quant = Image.open(\"../data/singapore.jpg\").convert('RGB')\n",
|
|
"# image_test_quant_onnx = np.array(image_test_quant)\n",
|
|
"\n",
|
|
"# clip_image_opt_sess = ort.InferenceSession(clip_image_opt_path)\n",
|
|
"# onnx_emb_quant_test = clip_image_opt_sess.run(None, {\"input\": image_test_quant_onnx})[0][0]\n",
|
|
"# onnx_emb_quant_test /= norm(onnx_emb_quant_test)\n",
|
|
"\n",
|
|
"\n",
|
|
"# for i in range(550, 600, exclude_amount):\n",
|
|
"# begin = i\n",
|
|
"# end = min(i+exclude_amount, len(image_node_names))\n",
|
|
"# exclude = list(set(exclude_for_sure + image_node_names[begin:end]))\n",
|
|
" \n",
|
|
"# clip_image_quantized_dynamic_debug_path = f\"onnx_models/mobileclip_s2_image_opset{onnx_opset}_int8dyn_opt_debug.onnx\"\n",
|
|
"# quantize_dynamic(clip_image_opt_path, clip_image_quantized_dynamic_debug_path, weight_type=QuantType.QUInt8, nodes_to_exclude=exclude)\n",
|
|
"# mobileclip_image_quant_dyn_ort_sess_debug = ort.InferenceSession(clip_image_quantized_dynamic_debug_path)\n",
|
|
"# image_onnx_quant_dyn_emb_debug = mobileclip_image_quant_dyn_ort_sess_debug.run([\"output\"], {\"input\": image_test_quant_onnx})[0][0]\n",
|
|
"# image_onnx_quant_dyn_emb_debug /= norm(image_onnx_quant_dyn_emb_debug)\n",
|
|
"# sim_debug = np.dot(image_onnx_quant_dyn_emb_debug, onnx_emb_quant_test)\n",
|
|
"# print(f\"Skipping nodes from {begin} to {end} resulted in a similarity of {sim_debug:.4f}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Float16 conversion for Image model\n",
|
|
"\n",
|
|
"https://onnxruntime.ai/docs/performance/model-optimizations/float16.html"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 58,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from onnxconverter_common import convert_float_to_float16"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 59,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"check_nodes_names = []\n",
|
|
"skip_nodes_names = []\n",
|
|
"try_image_model = onnx.load(clip_image_opt_path)\n",
|
|
"for node in try_image_model.graph.node:\n",
|
|
" check_nodes_names.append(node.name)\n",
|
|
"preprocess_nodes = check_nodes_names[:25]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"clip_image_fp16 = convert_float_to_float16(try_image_model, keep_io_types=True, disable_shape_infer=True, node_block_list=preprocess_nodes)\n",
|
|
"clip_image_fp16_path = f\"onnx_models/mobileclip_s2_image_opset{onnx_opset}_fp16.onnx\"\n",
|
|
"onnx.save(clip_image_fp16, clip_image_fp16_path)"
|
|
]
|
|
},
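{
"cell_type": "markdown",
"metadata": {},
"source": [
"As with the text encoder, a quick look (a sketch) at the on-disk size difference between the float32 and float16 image encoders."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"image_fp32_mb = os.path.getsize(clip_image_opt_path) / 1e6\n",
"image_fp16_mb = os.path.getsize(clip_image_fp16_path) / 1e6\n",
"print(f\"float32 image encoder: {image_fp32_mb:.1f} MB\")\n",
"print(f\"float16 image encoder: {image_fp16_mb:.1f} MB\")"
]
},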
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Test accuracy"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"image_onnx_input = np.array(Image.open(\"../data/singapore.jpg\").convert('RGB'))\n",
|
|
"try_sess_options = ort.SessionOptions()\n",
|
|
"try_sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED\n",
|
|
"# try_sess_options.inter_op_num_threads = 0\n",
|
|
"# try_sess_options.intra_op_num_threads = 0\n",
|
|
"# try_sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL\n",
|
|
"# try_sess_options.enable_profiling = True\n",
|
|
"# try_sess_options.log_severity_level = 0 # Verbose\n",
|
|
"clip_image_fp16_sess = ort.InferenceSession(clip_image_fp16_path, try_sess_options)\n",
|
|
"clip_image_sess = ort.InferenceSession(clip_image_opt_path, try_sess_options)\n",
|
|
"image_onnx_fp16_emb = clip_image_fp16_sess.run([\"output\"], {\"input\": image_onnx_input})[0][0]\n",
|
|
"image_onnx_fp16_emb /= norm(image_onnx_fp16_emb)\n",
|
|
"image_onnx_emb = clip_image_sess.run([\"output\"], {\"input\": image_onnx_input})[0][0]\n",
|
|
"image_onnx_emb /= norm(image_onnx_emb)\n",
|
|
"print(np.dot(image_onnx_fp16_emb, image_onnx_emb))\n",
|
|
"print(image_onnx_emb[0:5])\n",
|
|
"print(image_onnx_fp16_emb[0:5])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Test speed"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"time_test_size = 100\n",
|
|
"\n",
|
|
"begin_time_fp16 = time.time()\n",
|
|
"for i in tqdm(range(time_test_size)):\n",
|
|
" _ = clip_image_fp16_sess.run([\"output\"], {\"input\": image_onnx_input})\n",
|
|
"end_time_fp16 = time.time()\n",
|
|
"time_fp16 = end_time_fp16 - begin_time_fp16\n",
|
|
"\n",
|
|
"begin_time_opt = time.time()\n",
|
|
"for i in tqdm(range(time_test_size)):\n",
|
|
" _ = clip_image_sess.run([\"output\"], {\"input\": image_onnx_input})\n",
|
|
"end_time_opt = time.time()\n",
|
|
"time_opt = end_time_opt - begin_time_opt\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"print(f\"Optimized model: {time_opt:.6f} seconds, so {time_opt / time_test_size:.6f} seconds per inference\")\n",
|
|
"print(f\"FP16 model: {time_fp16:.6f} seconds, so {time_fp16 / time_test_size:.6f} seconds per inference\")\n",
|
|
"print(f\"Speed difference FP16: {time_opt / time_fp16:.2f}x\")"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "ente_clip",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.12.6"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}