Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Comment thread
MatteoFasulo marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ This file contains the changelog for the Deeploy project. The changelog is divid
- Fixed multiple typos in variable and method names, such as changing `includeGobalReferences` to `includeGlobalReferences` and `dicardedMappers` to `discardedMappers`
- Corrected method usage in `importDeeployState` to call `NetworkContext.importNetworkContext` instead of the incorrect method name
- Correctly return `signProp` from `setupDeployer` instead of hardcoding the value to `False` in `testMVP.py`
- Fixed `Unsqueeze` Op. when using ONNX opset 13 or higher (from attribute to input)

### Removed
- Delete outdated and unused `.gitlab-ci.yml` file
Expand Down
40 changes: 31 additions & 9 deletions Deeploy/Targets/Generic/Parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,10 +940,19 @@ def __init__(self):

def parseNode(self, node: gs.Node) -> (bool):

ret = all(['axes' in node.attrs, len(node.inputs) == 1, len(node.outputs) == 1])
# ONNX v11: 'axes' is a node attribute
if 'axes' in node.attrs:
ret = all(['axes' in node.attrs, len(node.inputs) == 1, len(node.outputs) == 1])
# ONNX v13+: 'axes' becomes an input with the data
# Source: https://onnx.ai/onnx/operators/onnx__Unsqueeze.html
else:
ret = all([len(node.inputs) == 2, len(node.outputs) == 1])

if ret:
self.operatorRepresentation['axes'] = node.attrs['axes']
if ret and 'axes' in node.attrs:
axes_attr = node.attrs['axes']
self.operatorRepresentation['axes'] = [int(axes_attr)] if isinstance(axes_attr, int) \
else [int(a) for a in axes_attr]
# For opset 13+, axes will be extracted from the second input in parseNodeCtxt

return ret

Expand All @@ -952,13 +961,26 @@ def parseNodeCtxt(self,
node: gs.Node,
channels_first: bool = True) -> Tuple[NetworkContext, bool]:

inputs = ['data_in']
outputs = ['data_out']

for idx, inputNode in enumerate(node.inputs):
self.operatorRepresentation[inputs[idx]] = ctxt.lookup(inputNode.name).name
for idx, outputNode in enumerate(node.outputs):
self.operatorRepresentation[outputs[idx]] = ctxt.lookup(outputNode.name).name
if len(node.inputs) == 1:
inputs = ['data_in']
for idx, inputNode in enumerate(node.inputs):
self.operatorRepresentation[inputs[idx]] = ctxt.lookup(inputNode.name).name
for idx, outputNode in enumerate(node.outputs):
self.operatorRepresentation[outputs[idx]] = ctxt.lookup(outputNode.name).name
else:
data_in = ctxt.lookup(node.inputs[0].name)
data_out = ctxt.lookup(node.outputs[0].name)
self.operatorRepresentation['data_in'] = data_in.name
self.operatorRepresentation['data_out'] = data_out.name
# axes must be a constant; extract values
axes_buf = ctxt.lookup(node.inputs[1].name)
assert hasattr(axes_buf, 'values'), "Unsqueeze: expected constant 'axes' input for opset 13+"
axes_vals = np.array(axes_buf.values).astype(int).flatten().tolist()
self.operatorRepresentation['axes'] = axes_vals
# Do not deploy the axes tensor
axes_buf._live = False
axes_buf._deploy = False

return ctxt, True

Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading