TVM TIR has many kinds of Node. Here is the complete list of printable types:
- StmtNode
- AnyNode
- PrimExprNode
- TypeNode
- PrimFuncNode
- IRModuleNode
- ArrayNode
- IterVarNode
- RangeNode
- BufferNode
- DataProducerNode
- StringObj
- BufferRegionNode
- TargetNode
To understand these Nodes better, it helps to look at their data structures, and also to map each of them onto concrete TIR code.
Example 1:
@main = primfn(n: int32) -> () {
  producer_realize(compute: DataProducer("compute", float32, [n]), [0:n], True, {
    for (i.outer: int32, 0, floordiv((n + 3), 4)) {
      for (i.inner: int32, 0, 4) {
        if @tir.likely(((i.inner + (i.outer*4)) < n), dtype=bool) {
          compute_1: DataProducer("compute", float32, [n])[(i.inner + (i.outer*4))] = (A[(i.inner + (i.outer*4))] + B[(i.inner + (i.outer*4))])
        }
      }
    }
  })
}
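The loop nest in Example 1 looks like what a split of a length-n loop by a factor of 4 produces: the outer loop runs floordiv(n + 3, 4) times (a ceiling division), and the tir.likely guard skips the tail iterations when n is not a multiple of 4. The plain-Python sketch below (no TVM needed; split_indices is a hypothetical helper, not a TVM API) replays that nest and checks that the guarded body touches every index 0..n-1 exactly once.

```python
from typing import List


def split_indices(n: int, factor: int = 4) -> List[int]:
    """Replay the loop nest of Example 1 in plain Python (hypothetical helper)."""
    visited = []
    # Outer extent: floordiv(n + 3, 4) for factor 4, i.e. ceil(n / factor).
    outer_extent = (n + factor - 1) // factor
    for i_outer in range(outer_extent):
        for i_inner in range(factor):
            i = i_inner + i_outer * factor
            # Mirrors the tir.likely((i.inner + i.outer*4) < n) guard:
            # iterations past the end of the data are skipped.
            if i < n:
                visited.append(i)
    return visited


for n in (0, 1, 4, 7, 10):
    # Each index 0..n-1 should be visited exactly once, in order.
    print(n, (n + 3) // 4, split_indices(n) == list(range(n)))
```

For n = 7, for instance, the outer loop runs twice (8 inner iterations) and the guard drops the last one.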
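Breaking the program down into its individual elements (in these listings the size variable n is printed as {n|n>=0}, i.e. as a SizeVar known to be non-negative):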
PrimExprNode :
(i.outer*4)
A[(i.inner + (i.outer*4))]
({n|n>=0} + 3)
i.outer
B[(i.inner + (i.outer*4))]
(i.inner + (i.outer*4))
tir.likely(((i.inner + (i.outer*4)) < {n|n>=0}))
((i.inner + (i.outer*4)) < {n|n>=0})
0
(bool)1
{n|n>=0}
i.inner
3
4
(A[(i.inner + (i.outer*4))] + B[(i.inner + (i.outer*4))])
floordiv(({n|n>=0} + 3), 4)
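The PrimExprNode list above seems to contain each distinct sub-expression once, even though (i.inner + (i.outer*4)) occurs many times in the program. One plausible way to produce such a listing is a post-order walk that records every node by its printed form and skips repeats. The sketch below assumes exactly that; Var, IntImm, BinOp, Load and collect_unique are hypothetical toy stand-ins, not TVM classes (in real TVM a similar walk could be done with tvm.tir.stmt_functor.post_order_visit). It applies the idea to the right-hand side of the store in Example 1.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class Var:
    name: str

    def __str__(self) -> str:
        return self.name


@dataclass(frozen=True)
class IntImm:
    value: int

    def __str__(self) -> str:
        return str(self.value)


@dataclass(frozen=True)
class BinOp:
    op: str  # " + " or "*", matching the spacing of the printout above
    a: object
    b: object

    def __str__(self) -> str:
        return f"({self.a}{self.op}{self.b})"


@dataclass(frozen=True)
class Load:
    buf: str  # a buffer read such as A[...]
    index: object

    def __str__(self) -> str:
        return f"{self.buf}[{self.index}]"


def collect_unique(expr: object) -> List[str]:
    """Post-order walk that keeps the first occurrence of each printed form."""
    seen: List[str] = []

    def visit(e: object) -> None:
        # Visit children first (post-order), then the node itself.
        if isinstance(e, BinOp):
            visit(e.a)
            visit(e.b)
        elif isinstance(e, Load):
            visit(e.index)
        text = str(e)
        if text not in seen:
            seen.append(text)

    visit(expr)
    return seen


index = BinOp(" + ", Var("i.inner"), BinOp("*", Var("i.outer"), IntImm(4)))
value = BinOp(" + ", Load("A", index), Load("B", index))
for item in collect_unique(value):
    print(item)
```

The shared index expression is visited twice but recorded only once, which matches how it appears a single time in the list.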
TypeNode :
TupleTypeNode([])
int32
StmtNode :
for (i.outer, 0, floordiv(({n|n>=0} + 3), 4)) {
  for (i.inner, 0, 4) {
    if (tir.likely(((i.inner + (i.outer*4)) < {n|n>=0}))) {
      compute[(i.inner + (i.outer*4))] =(A[(i.inner + (i.outer*4))] + B[(i.inner + (i.outer*4))])
    }
  }
}
for (i.inner, 0, 4) {
  if (tir.likely(((i.inner + (i.outer*4)) < {n|n>=0}))) {
    compute[(i.inner + (i.outer*4))] =(A[(i.inner + (i.outer*4))] + B[(i.inner + (i.outer*4))])
  }
}
compute[(i.inner + (i.outer*4))] =(A[(i.inner + (i.outer*4))] + B[(i.inner + (i.outer*4))])
if (tir.likely(((i.inner + (i.outer*4)) < {n|n>=0}))) {
  compute[(i.inner + (i.outer*4))] =(A[(i.inner + (i.outer*4))] + B[(i.inner + (i.outer*4))])
}
producer_realize compute([0, {n|n>=0}]) {
  for (i.outer, 0, floordiv(({n|n>=0} + 3), 4)) {
    for (i.inner, 0, 4) {
      if (tir.likely(((i.inner + (i.outer*4)) < {n|n>=0}))) {
        compute[(i.inner + (i.outer*4))] =(A[(i.inner + (i.outer*4))] + B[(i.inner + (i.outer*4))])
      }
    }
  }
}
DataProducerNode :
Tensor(shape=[{n|n>=0}], op.name=compute)
StringObj :
"compute"
ArrayNode :
[{n|n>=0}]
[(i.inner + (i.outer*4))]
[range(min=0, ext={n|n>=0})]
RangeNode :
range(min=0, ext={n|n>=0})
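Note that nodes of a subtype are always listed under the name of their parent class. For example, ForNode inherits from StmtNode, so the for loops above appear under StmtNode; likewise the if statement (IfThenElseNode), the store (ProducerStoreNode) and producer_realize (ProducerRealizeNode) are all shown as StmtNode, and the Tensor listed under DataProducerNode is a TensorNode.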
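The grouping rule in this note can be mimicked with a small class hierarchy: walk a node's class ancestry (Python's method resolution order) and report the first class that appears in the printable list. The class names below mirror TVM's C++ node classes, but they are empty Python stand-ins, and PRINTABLE and display_category are hypothetical names. This is only a sketch of the rule as described, not how TVM's printer is actually implemented.

```python
from typing import Optional


class Object:
    """Stand-in for TVM's root object class (hypothetical, empty)."""


class StmtNode(Object):
    pass


class ForNode(StmtNode):
    pass


class IfThenElseNode(StmtNode):
    pass


class PrimExprNode(Object):
    pass


class AddNode(PrimExprNode):
    pass


class VarNode(PrimExprNode):
    pass


class SizeVarNode(VarNode):
    pass


class TypeNode(Object):
    pass


class TupleTypeNode(TypeNode):
    pass


class DataProducerNode(Object):
    pass


class TensorNode(DataProducerNode):
    pass


# Hypothetical subset of the "printable" base classes listed at the top.
PRINTABLE = (StmtNode, PrimExprNode, TypeNode, DataProducerNode)


def display_category(node: Object) -> Optional[str]:
    # type(node).__mro__ starts with the class itself, then its ancestors,
    # so the first printable hit is the nearest printable ancestor.
    for cls in type(node).__mro__:
        if cls in PRINTABLE:
            return cls.__name__
    return None


for node in (ForNode(), IfThenElseNode(), SizeVarNode(), AddNode(),
             TupleTypeNode(), TensorNode()):
    print(type(node).__name__, "->", display_category(node))
```

SizeVarNode sits two levels below PrimExprNode, which shows why walking the whole ancestry, rather than checking only the direct parent, is needed.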