Skip to content

Commit c28aeb2

Browse files
committed
fix onnx_text_plot_tree
1 parent 674eb27 commit c28aeb2

File tree

3 files changed

+64
-19
lines changed

3 files changed

+64
-19
lines changed
Binary file not shown.

_unittests/ut_plotting/test_text_plot.py

+32-9
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def test_onnx_text_plot_tree_reg(self):
5252
onx = to_onnx(clr, X)
5353
res = onnx_text_plot_tree(onx.graph.node[0])
5454
self.assertIn("treeid=0", res)
55-
self.assertIn(" T y=", res)
55+
self.assertIn(" +f", res)
5656

5757
def test_onnx_text_plot_tree_cls(self):
5858
iris = load_iris()
@@ -62,20 +62,43 @@ def test_onnx_text_plot_tree_cls(self):
6262
onx = to_onnx(clr, X)
6363
res = onnx_text_plot_tree(onx.graph.node[0])
6464
self.assertIn("treeid=0", res)
65-
self.assertIn(" T y=", res)
65+
self.assertIn(" +f 0:", res)
6666
self.assertIn("n_classes=3", res)
6767

6868
def test_onnx_text_plot_tree_cls_2(self):
69-
iris = load_iris()
70-
X_train, y_train = iris.data.astype(numpy.float32), iris.target
71-
clr = DecisionTreeClassifier()
72-
clr.fit(X_train, y_train)
73-
model_def = to_onnx(
74-
clr, X_train.astype(numpy.float32), options={"zipmap": False}
69+
this = os.path.join(
70+
os.path.dirname(__file__), "data", "onnx_text_plot_tree_cls_2.onnx"
7571
)
72+
with open(this, "rb") as f:
73+
model_def = load(f)
7674
res = onnx_text_plot_tree(model_def.graph.node[0])
7775
self.assertIn("n_classes=3", res)
78-
print(res)
76+
expected = textwrap.dedent(
77+
"""
78+
n_classes=3
79+
n_trees=1
80+
----
81+
treeid=0
82+
n X2 <= 2.4499998
83+
-n X3 <= 1.75
84+
-n X2 <= 4.85
85+
-f 0:0 1:0 2:1
86+
+n X0 <= 5.95
87+
-f 0:0 1:0 2:1
88+
+f 0:0 1:1 2:0
89+
+n X2 <= 4.95
90+
-n X3 <= 1.55
91+
-n X0 <= 6.95
92+
-f 0:0 1:0 2:1
93+
+f 0:0 1:1 2:0
94+
+f 0:0 1:0 2:1
95+
+n X3 <= 1.65
96+
-f 0:0 1:0 2:1
97+
+f 0:0 1:1 2:0
98+
+f 0:1 1:0 2:0
99+
"""
100+
).strip(" \n\r")
101+
self.assertEqual(expected, res.strip(" \n\r"))
79102

80103
@ignore_warnings((UserWarning, FutureWarning))
81104
def test_onnx_simple_text_plot_kmeans(self):

onnx_array_api/plotting/text_plot.py

+32-10
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@ def _rule(r):
2424
raise ValueError(f"Unexpected rule {r!r}.")
2525

2626

27+
def _number2str(i):
28+
if isinstance(i, int):
29+
return str(i)
30+
if int(i) == i:
31+
return str(int(i))
32+
return f"{i:1.2f}"
33+
34+
2735
def onnx_text_plot_tree(node):
2836
"""
2937
Gives a textual representation of a tree ensemble.
@@ -61,18 +69,32 @@ def __init__(self, i, atts):
6169
setattr(self, k, v[i])
6270
self.depth = 0
6371
self.true_false = ""
72+
self.targets = []
73+
74+
def append_target(self, tid, weight):
75+
self.targets.append(dict(target_id=tid, weight=weight))
6476

6577
def process_node(self):
6678
"node to string"
6779
if self.nodes_modes == "LEAF":
68-
text = "%s y=%r f=%r i=%r" % (
69-
self.true_false,
70-
self.target_weights,
71-
self.target_ids,
72-
self.target_nodeids,
73-
)
80+
if len(self.targets) == 0:
81+
text = f"{self.true_false}f"
82+
elif len(self.targets) == 1:
83+
t = self.targets[0]
84+
text = (
85+
f"{self.true_false}f "
86+
f"{t['target_id']}:{_number2str(t['weight'])}"
87+
)
88+
else:
89+
ts = " ".join(
90+
map(
91+
lambda t: f"{t['target_id']}:{_number2str(t['weight'])}",
92+
self.targets,
93+
)
94+
)
95+
text = f"{self.true_false}f {ts}"
7496
else:
75-
text = "%s X%d %s %r" % (
97+
text = "%sn X%d %s %r" % (
7698
self.true_false,
7799
self.nodes_featureids,
78100
_rule(self.nodes_modes),
@@ -115,7 +137,7 @@ def process_tree(atts, treeid):
115137
idn = short[f"{prefix}_nodeids"][i]
116138
node = nodes[idn]
117139
node.append_target(
118-
id=short[f"{prefix}_ids"][i], weight=short[f"{prefix}_weights"][i]
140+
tid=short[f"{prefix}_ids"][i], weight=short[f"{prefix}_weights"][i]
119141
)
120142

121143
def iterate(nodes, node, depth=0, true_false=""):
@@ -127,14 +149,14 @@ def iterate(nodes, node, depth=0, true_false=""):
127149
nodes,
128150
nodes[node.nodes_falsenodeids],
129151
depth=depth + 1,
130-
true_false="F",
152+
true_false="-",
131153
):
132154
yield n
133155
for n in iterate(
134156
nodes,
135157
nodes[node.nodes_truenodeids],
136158
depth=depth + 1,
137-
true_false="T",
159+
true_false="+",
138160
):
139161
yield n
140162

0 commit comments

Comments
 (0)