Skip to content

Commit 5ad3fc9

Browse files
committed
add basic codegen test, fix binary tree rebalancing
1 parent c7ac24f commit 5ad3fc9

File tree

2 files changed

+50
-31
lines changed

2 files changed

+50
-31
lines changed

src/cmd/compile/internal/ssa/reassociate.go

+32-31
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,10 @@
44

55
package ssa
66

7-
import (
8-
"fmt"
9-
"sort"
10-
)
11-
12-
// balanceExprTree repurposes all nodes and leafs into a
13-
// balanced expression tree
7+
// balanceExprTree repurposes all nodes and leafs into a well-balanced expression tree.
8+
// It doesn't truly balance the tree in the sense of a BST, rather it
9+
// prioritizes pairing up innermost (rightmost) expressions and their results and only
10+
// pairing results of outermost (leftmost) expressions up with them when no other nice pairing exists.
1411
func balanceExprTree(v *Value, visited map[*Value]bool, nodes, leafs []*Value) {
1512
// reset all arguments of nodes to help rebalancing
1613
for i, n := range nodes {
@@ -34,27 +31,27 @@ func balanceExprTree(v *Value, visited map[*Value]bool, nodes, leafs []*Value) {
3431
nodes[i], nodes[j] = nodes[j], nodes[i]
3532
}
3633

37-
// push all leafs which are constants as far off to the
38-
// right as possible to give the constant folder more opportunities
39-
sort.Slice(leafs, func(i, j int) bool {
40-
switch leafs[j].Op {
41-
case OpConst8, OpConst16, OpConst32, OpConst64:
42-
return false
43-
default:
44-
return true
34+
// rebuild expression trees from the bottom up, prioritizing
35+
// right grouping.
36+
// if the number of leaves is not even, skip the first leaf
37+
// and add it to be paired up later
38+
i := 0
39+
subTrees := leafs
40+
for len(subTrees) != 1 {
41+
nextSubTrees := make([]*Value, 0, (len(subTrees)+1)/2)
42+
43+
start := len(subTrees)%2
44+
if start != 0 {
45+
nextSubTrees = append(nextSubTrees, subTrees[0])
4546
}
46-
})
47-
48-
// build tree in reverse topological order
49-
for i := 0; i < len(nodes); i++ {
50-
if len(leafs) < 2 { // we need at least two leafs per node, balance went very wrong
51-
panic(fmt.Sprint("leafs needs to be >= 2, got", len(leafs)))
47+
48+
for j := start; j < len(subTrees)-1; j+=2 {
49+
nodes[i].AddArg2(subTrees[j], subTrees[j+1])
50+
nextSubTrees = append(nextSubTrees, nodes[i])
51+
i++
5252
}
53-
54-
// Take two leaves out and attach them to a node,
55-
// use the node as a new leaf in the "next layer" of the tree
56-
nodes[i].AddArg2(leafs[0], leafs[1])
57-
leafs = append(leafs[2:], nodes[i])
53+
54+
subTrees = nextSubTrees
5855
}
5956
}
6057

@@ -72,7 +69,7 @@ func isOr(op Op) bool {
7269
//
7370
// (l | l << 8 | l << 16 | l << 24)
7471
//
75-
// which cannot be rebalanced or else it won't fire rewrite rules
72+
// which cannot be rebalanced or else it won't fire load widening rewrite rules
7673
func probablyMemcombine(op Op, leafs []*Value) bool {
7774
if !isOr(op) {
7875
return false
@@ -89,7 +86,11 @@ func probablyMemcombine(op Op, leafs []*Value) bool {
8986
}
9087
}
9188

92-
return lshCount == len(leafs)-1
89+
// there are a few algorithms in the std lib expressed as two 32 bit loads
90+
// which can get turned into a 64 bit load
91+
// conservatively estimate that if there are more shifts than not then it is
92+
// some sort of load waiting to be widened
93+
return lshCount > len(leafs)/2
9394
}
9495

9596
// rebalance balances associative computation to better help CPU instruction pipelining (#49331)
@@ -145,7 +146,7 @@ func rebalance(v *Value, visited map[*Value]bool) {
145146
}
146147

147148
// we need at least 4 leafs for this expression to be rebalanceable,
148-
// and we can't balance a potential load widening (memcombine)
149+
// and we can't balance a potential load widening (see memcombine)
149150
if len(leafs) < 4 || probablyMemcombine(v.Op, leafs) {
150151
return
151152
}
@@ -154,8 +155,8 @@ func rebalance(v *Value, visited map[*Value]bool) {
154155
}
155156

156157
// reassociate balances trees of commutative computation
157-
// to better group expressions for better constant folding,
158-
// cse, etc.
158+
// to better group expressions to expose easy optimizations in
159+
// cse, cancelling/counting/factoring expressions, etc.
159160
func reassociate(f *Func) {
160161
visited := make(map[*Value]bool)
161162

test/codegen/reassociate.go

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// asmcheck
2+
3+
package codegen
4+
5+
// reassociateAddition expects a very specific sequence of registers
6+
// of the form:
7+
// R2 += R3
8+
// R1 += R0
9+
// R0 = R1 + R2
10+
func reassociateAddition(a, b, c, d int) int {
11+
// arm64:`ADD\tR2,\sR3,\sR2`
12+
x := b + a
13+
// arm64:`ADD\tR0,\sR1,\sR1`
14+
y := x + c
15+
// arm64:`ADD\tR1,\sR2,\sR0`
16+
z := y + d
17+
return z
18+
}

0 commit comments

Comments
 (0)