- Thread Author
- #1
I am trying to convert the [DETR model][1] to TensorFlow using ONNX. I converted the model using torch.onnx.export with opset_version=12 (which produces a detr.onnx file).
Then I tried to convert the ONNX file to a TensorFlow model using [this example][2]. I added an onnx.checker.check_model line to make sure the model is loaded correctly.
This code raises an exception when it reaches the tf_rep.export_graph('./model.pb') line.
message of exception :
Then I tried to convert the ONNX file to a TensorFlow model using [this example][2]. I added an onnx.checker.check_model line to make sure the model is loaded correctly.
Code:
import math
from PIL import Image
import requests
import matplotlib.pyplot as plt
import torch
from torch import nn
from torchvision.models import resnet50
import torchvision.transforms as T
torch.set_grad_enabled(False)
model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
im = Image.open(requests.get(url, stream=True).raw)
transform = T.Compose([
T.Resize(800),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
img = transform(im).unsqueeze(0)
onnx_model = onnx.load('./detr.onnx')
result = onnx.checker.check_model(onnx_model)
tf_rep = prepare(onnx_model)
tf_rep.export_graph('./model.pb')
Exception message:
Code:
KeyError Traceback (most recent call last)
Cell In[19], line 26
23 result = onnx.checker.check_model(onnx_model)
25 tf_rep = prepare(onnx_model)
---> 26 tf_rep.export_graph('./model.pb')
File ~\AppData\Local\Programs\Python\Python39\lib\site-packages\onnx_tf\backend_rep.py:143, in TensorflowRep.export_graph(self, path)
129 """Export backend representation to a Tensorflow proto file.
130
131 This function obtains the graph proto corresponding to the ONNX
(...)
137 :returns: none.
138 """
139 self.tf_module.is_export = True
140 tf.saved_model.save(
141 self.tf_module,
142 path,
--> 143 signatures=self.tf_module.__call__.get_concrete_function(
144 **self.signatures))
145 self.tf_module.is_export = False
File ~\AppData\Local\Programs\Python\Python39\lib\site-packages\tensorflow\python\eager\def_function.py:1239, in Function.get_concrete_function(self, *args, **kwargs)
1237 def get_concrete_function(self, *args, **kwargs):
1238 # Implements GenericFunction.get_concrete_function.
-> 1239 concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
1240 concrete._garbage_collector.release() # pylint: disable=protected-access
1241 return concrete
File ~\AppData\Local\Programs\Python\Python39\lib\site-packages\tensorflow\python\eager\def_function.py:1219, in Function._get_concrete_function_garbage_collected(self, *args, **kwargs)
1217 if self._stateful_fn is None:
1218 initializers = []
-> 1219 self._initialize(args, kwargs, add_initializers_to=initializers)
1220 self._initialize_uninitialized_variables(initializers)
1222 if self._created_variables:
1223 # In this case we have created variables on the first call, so we run the
1224 # defunned version which is guaranteed to never create variables.
File ~\AppData\Local\Programs\Python\Python39\lib\site-packages\tensorflow\python\eager\def_function.py:785, in Function._initialize(self, args, kwds, add_initializers_to)
782 self._lifted_initializer_graph = lifted_initializer_graph
783 self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
784 self._concrete_stateful_fn = (
--> 785 self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
786 *args, **kwds))
788 def invalid_creator_scope(*unused_args, **unused_kwds):
789 """Disables variable creation."""
File ~\AppData\Local\Programs\Python\Python39\lib\site-packages\tensorflow\python\eager\function.py:2523, in Function._get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
2521 args, kwargs = None, None
2522 with self._lock:
-> 2523 graph_function, _ = self._maybe_define_function(args, kwargs)
2524 return graph_function
File ~\AppData\Local\Programs\Python\Python39\lib\site-packages\tensorflow\python\eager\function.py:2760, in Function._maybe_define_function(self, args, kwargs)
2758 # Only get placeholders for arguments, not captures
2759 args, kwargs = placeholder_dict["args"]
-> 2760 graph_function = self._create_graph_function(args, kwargs)
2762 graph_capture_container = graph_function.graph._capture_func_lib # pylint: disable=protected-access
2763 # Maintain the list of all captures