From ffdb3c9629a91f228bf1688382fd18f77e2b79d9 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Fri, 18 Oct 2024 17:15:13 +0530 Subject: [PATCH] [infra] Solve yolo splits issue in model creation --- infra/ml/YOLOv5Face/yoloface_onnx.ipynb | 105 ++++++++++++++++++++++-- 1 file changed, 98 insertions(+), 7 deletions(-) diff --git a/infra/ml/YOLOv5Face/yoloface_onnx.ipynb b/infra/ml/YOLOv5Face/yoloface_onnx.ipynb index b7b157be5f..f89fc1c51c 100644 --- a/infra/ml/YOLOv5Face/yoloface_onnx.ipynb +++ b/infra/ml/YOLOv5Face/yoloface_onnx.ipynb @@ -734,6 +734,96 @@ "opt_session = ort.InferenceSession(onnx_model_sim_path, opt_sess_options)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Prevent splits initializer issue\n", + "\n", + "For some weird reason the model can give issues on iOS when there's an initializer named \"splits\". \n", + "So to prevent that we check and rename any such initializer" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "current_model = onnx.load(onnx_model_opt_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def find_duplicates(name_list):\n", + " seen = set()\n", + " duplicates = set()\n", + " \n", + " for name in name_list:\n", + " if name in seen:\n", + " duplicates.add(name)\n", + " else:\n", + " seen.add(name)\n", + " \n", + " return list(duplicates)\n", + "\n", + "# Get the list of initializers\n", + "initializers = current_model.graph.initializer\n", + "init_names = [init.name for init in initializers]\n", + "\n", + "# If you want to store the initializers and their names in a dictionary\n", + "initializer_dict = {init.name: init for init in initializers}\n", + "init_names = [init.name for init in initializers]\n", + "\n", + "print(f\"splits initializer: \\n {initializer_dict[\"splits\"]}\")\n", + "\n", + "duplicate_names = find_duplicates(init_names)\n", + "\n", + "print(\"Duplicate names:\", duplicate_names)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "def rename_initializer(model, old_name, new_name):\n", + " for initializer in model.graph.initializer:\n", + " if initializer.name == old_name:\n", + " initializer.name = new_name\n", + " break\n", + " \n", + " # Update any references to this initializer in the graph inputs\n", + " for input in model.graph.input:\n", + " if input.name == old_name:\n", + " input.name = new_name\n", + " \n", + " # Update references in nodes\n", + " for node in model.graph.node:\n", + " for i, input_name in enumerate(node.input):\n", + " if input_name == old_name:\n", + " node.input[i] = new_name" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "rename_initializer(current_model, \"splits\", \"splits_initializer_unique\")\n", + "\n", + "# Save the modified model\n", + "onnx_model_opt_with_splits_path = onnx_model_opt_path\n", + "onnx_model_opt_path = onnx_model_opt_path[:-5] + \"_nosplits.onnx\"\n", + "onnx.save(current_model, onnx_model_opt_path)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -745,7 +835,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 24, "metadata": {}, "outputs": [], "source": [ @@ -759,7 +849,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ @@ -768,7 +858,8 @@ "!rm {onnx_model_split_path}\n", "!rm {onnx_model_nms_path}\n", "!rm {onnx_model_prepostpro_path}\n", - "!rm {onnx_model_sim_path}" + "!rm {onnx_model_sim_path}\n", + "!rm {onnx_model_opt_with_splits_path}" ] }, { @@ -780,7 +871,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ @@ -790,7 +881,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ @@ -874,7 +965,7 @@ }, { "cell_type": "code", - "execution_count": 69, + "execution_count": 29, "metadata": {}, "outputs": [], "source": [ @@ -903,7 +994,7 @@ }, { "cell_type": "code", - "execution_count": 70, + "execution_count": 30, "metadata": {}, "outputs": [], "source": [