Skip to content

Commit 9907a3e

Browse files
lara-hdrfacebook-github-bot
authored andcommitted
Update Argmin/Argmax ONNX Export (#38329)
Summary: Update Argmin/Argmax ONNX export in opset 12 to export with "select_last_index", and export correctly cases where the same value appears multiple time in the input tensor. Pull Request resolved: #38329 Reviewed By: hl475 Differential Revision: D21613799 Pulled By: houseroad fbshipit-source-id: 4597e23561f444c4e56d30c735dae7e9a8a41c5e
1 parent cbd0adc commit 9907a3e

File tree

2 files changed

+49
-1
lines changed

2 files changed

+49
-1
lines changed

test/onnx/test_pytorch_onnx_onnxruntime.py

+26
Original file line numberDiff line numberDiff line change
@@ -2045,6 +2045,32 @@ def forward(self, input, other):
20452045
y = torch.randint(10, (2, 4, 5))
20462046
self.run_test(MatmulModel(), (x, y))
20472047

2048+
def _argmin_argmax_model(self, input):
2049+
class ArgminArgmaxModel(torch.nn.Module):
2050+
def forward(self, input):
2051+
return torch.argmin(input), \
2052+
torch.argmax(input), \
2053+
torch.argmin(input, keepdim=True), \
2054+
torch.argmax(input, keepdim=True)
2055+
2056+
self.run_test(ArgminArgmaxModel(), input)
2057+
2058+
def test_argmin_argmax(self):
2059+
input = torch.randn(7, 3, 5)
2060+
self._argmin_argmax_model(input)
2061+
2062+
# Argmin and Argmax with "select_last_index" is not supprted before opset 12
2063+
# "select_last_index" was added in opset 12 to deal with corner case where the
2064+
# same value appears multiple times in the tensor
2065+
@skipIfUnsupportedMinOpsetVersion(12)
2066+
def test_argmin_argmax_select_last_index(self):
2067+
input = torch.tensor([[1., 2., 3.],
2068+
[1., 1., 2.]])
2069+
self._argmin_argmax_model(input)
2070+
2071+
input = torch.ones(7, 3, 5)
2072+
self._argmin_argmax_model(input)
2073+
20482074
def test_view(self):
20492075
class ViewModel(torch.nn.Module):
20502076
def forward(self, input):

torch/onnx/symbolic_opset12.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import torch
44
import torch.onnx.symbolic_helper as sym_help
5-
from torch.onnx.symbolic_helper import parse_args
5+
from torch.onnx.symbolic_helper import parse_args, _parse_arg
66

77

88
# EDITING THIS FILE? READ THIS FIRST!
@@ -62,6 +62,28 @@ def nll_loss2d(g, self, target, weight, reduction, ignore_index):
6262
return nll_loss(g, self, target, weight, reduction, ignore_index)
6363

6464

65+
def argmax(g, input, dim, keepdim):
66+
if sym_help._is_none(dim):
67+
from torch.onnx.symbolic_opset9 import reshape
68+
flattened = reshape(g, input, (-1,))
69+
return g.op('ArgMax', flattened, axis_i=0, keepdims_i=False, select_last_index_i=True)
70+
else:
71+
dim = _parse_arg(dim, 'i')
72+
keepdim = _parse_arg(keepdim, 'i')
73+
return g.op('ArgMax', input, axis_i=dim, keepdims_i=keepdim, select_last_index_i=True)
74+
75+
76+
def argmin(g, input, dim, keepdim):
77+
if sym_help._is_none(dim):
78+
from torch.onnx.symbolic_opset9 import reshape
79+
flattened = reshape(g, input, (-1,))
80+
return g.op('ArgMin', flattened, axis_i=0, keepdims_i=False, select_last_index_i=True)
81+
else:
82+
dim = _parse_arg(dim, 'i')
83+
keepdim = _parse_arg(keepdim, 'i')
84+
return g.op('ArgMin', input, axis_i=dim, keepdims_i=keepdim, select_last_index_i=True)
85+
86+
6587
def pow(g, self, exponent):
6688
return g.op("Pow", self, exponent)
6789

0 commit comments

Comments
 (0)