[infra] Solve yolo splits issue in model creation

This commit is contained in:
laurenspriem 2024-10-18 17:15:13 +05:30
parent 442c20b175
commit ffdb3c9629

View File

@ -734,6 +734,96 @@
"opt_session = ort.InferenceSession(onnx_model_sim_path, opt_sess_options)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Prevent splits initializer issue\n",
"\n",
"For some weird reason the model can give issues on iOS when there's an initializer named \"splits\". \n",
"So to prevent that we check and rename any such initializer"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"current_model = onnx.load(onnx_model_opt_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def find_duplicates(name_list):\n",
" seen = set()\n",
" duplicates = set()\n",
" \n",
" for name in name_list:\n",
" if name in seen:\n",
" duplicates.add(name)\n",
" else:\n",
" seen.add(name)\n",
" \n",
" return list(duplicates)\n",
"\n",
"# Get the list of initializers\n",
"initializers = current_model.graph.initializer\n",
"init_names = [init.name for init in initializers]\n",
"\n",
"# If you want to store the initializers and their names in a dictionary\n",
"initializer_dict = {init.name: init for init in initializers}\n",
"init_names = [init.name for init in initializers]\n",
"\n",
"print(f\"splits initializer: \\n {initializer_dict[\"splits\"]}\")\n",
"\n",
"duplicate_names = find_duplicates(init_names)\n",
"\n",
"print(\"Duplicate names:\", duplicate_names)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"def rename_initializer(model, old_name, new_name):\n",
" for initializer in model.graph.initializer:\n",
" if initializer.name == old_name:\n",
" initializer.name = new_name\n",
" break\n",
" \n",
" # Update any references to this initializer in the graph inputs\n",
" for input in model.graph.input:\n",
" if input.name == old_name:\n",
" input.name = new_name\n",
" \n",
" # Update references in nodes\n",
" for node in model.graph.node:\n",
" for i, input_name in enumerate(node.input):\n",
" if input_name == old_name:\n",
" node.input[i] = new_name"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"rename_initializer(current_model, \"splits\", \"splits_initializer_unique\")\n",
"\n",
"# Save the modified model\n",
"onnx_model_opt_with_splits_path = onnx_model_opt_path\n",
"onnx_model_opt_path = onnx_model_opt_path[:-5] + \"_nosplits.onnx\"\n",
"onnx.save(current_model, onnx_model_opt_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
@ -745,7 +835,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
@ -759,7 +849,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
@ -768,7 +858,8 @@
"!rm {onnx_model_split_path}\n",
"!rm {onnx_model_nms_path}\n",
"!rm {onnx_model_prepostpro_path}\n",
"!rm {onnx_model_sim_path}"
"!rm {onnx_model_sim_path}\n",
"!rm {onnx_model_opt_with_splits_path}"
]
},
{
@ -780,7 +871,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
@ -790,7 +881,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
@ -874,7 +965,7 @@
},
{
"cell_type": "code",
"execution_count": 69,
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
@ -903,7 +994,7 @@
},
{
"cell_type": "code",
"execution_count": 70,
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [