mirror of
https://github.com/ente-io/ente.git
synced 2025-08-08 07:28:26 +00:00
[infra] Solve yolo splits issue in model creation
This commit is contained in:
parent
442c20b175
commit
ffdb3c9629
@ -734,6 +734,96 @@
|
||||
"opt_session = ort.InferenceSession(onnx_model_sim_path, opt_sess_options)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Prevent splits initializer issue\n",
|
||||
"\n",
|
||||
"For some weird reason the model can give issues on iOS when there's an initializer named \"splits\". \n",
|
||||
"So to prevent that we check and rename any such initializer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"current_model = onnx.load(onnx_model_opt_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def find_duplicates(name_list):\n",
|
||||
" seen = set()\n",
|
||||
" duplicates = set()\n",
|
||||
" \n",
|
||||
" for name in name_list:\n",
|
||||
" if name in seen:\n",
|
||||
" duplicates.add(name)\n",
|
||||
" else:\n",
|
||||
" seen.add(name)\n",
|
||||
" \n",
|
||||
" return list(duplicates)\n",
|
||||
"\n",
|
||||
"# Get the list of initializers\n",
|
||||
"initializers = current_model.graph.initializer\n",
|
||||
"init_names = [init.name for init in initializers]\n",
|
||||
"\n",
|
||||
"# If you want to store the initializers and their names in a dictionary\n",
|
||||
"initializer_dict = {init.name: init for init in initializers}\n",
|
||||
"init_names = [init.name for init in initializers]\n",
|
||||
"\n",
|
||||
"print(f\"splits initializer: \\n {initializer_dict[\"splits\"]}\")\n",
|
||||
"\n",
|
||||
"duplicate_names = find_duplicates(init_names)\n",
|
||||
"\n",
|
||||
"print(\"Duplicate names:\", duplicate_names)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def rename_initializer(model, old_name, new_name):\n",
|
||||
" for initializer in model.graph.initializer:\n",
|
||||
" if initializer.name == old_name:\n",
|
||||
" initializer.name = new_name\n",
|
||||
" break\n",
|
||||
" \n",
|
||||
" # Update any references to this initializer in the graph inputs\n",
|
||||
" for input in model.graph.input:\n",
|
||||
" if input.name == old_name:\n",
|
||||
" input.name = new_name\n",
|
||||
" \n",
|
||||
" # Update references in nodes\n",
|
||||
" for node in model.graph.node:\n",
|
||||
" for i, input_name in enumerate(node.input):\n",
|
||||
" if input_name == old_name:\n",
|
||||
" node.input[i] = new_name"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"rename_initializer(current_model, \"splits\", \"splits_initializer_unique\")\n",
|
||||
"\n",
|
||||
"# Save the modified model\n",
|
||||
"onnx_model_opt_with_splits_path = onnx_model_opt_path\n",
|
||||
"onnx_model_opt_path = onnx_model_opt_path[:-5] + \"_nosplits.onnx\"\n",
|
||||
"onnx.save(current_model, onnx_model_opt_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
@ -745,7 +835,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -759,7 +849,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -768,7 +858,8 @@
|
||||
"!rm {onnx_model_split_path}\n",
|
||||
"!rm {onnx_model_nms_path}\n",
|
||||
"!rm {onnx_model_prepostpro_path}\n",
|
||||
"!rm {onnx_model_sim_path}"
|
||||
"!rm {onnx_model_sim_path}\n",
|
||||
"!rm {onnx_model_opt_with_splits_path}"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -780,7 +871,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"execution_count": 26,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -790,7 +881,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"execution_count": 27,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -874,7 +965,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 69,
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -903,7 +994,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 70,
|
||||
"execution_count": 30,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
Loading…
x
Reference in New Issue
Block a user