Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wazevo(regalloc): refactors interference graph construction #1827

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions internal/engine/wazevo/backend/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func newMachine() backend.Machine {
}

func TestE2E(t *testing.T) {
const verbose = false
const verbose = true

type testCase struct {
name string
Expand Down Expand Up @@ -1518,33 +1518,33 @@ L1 (SSA Block: blk0):
add w8, w22, w8
add w0, w23, w8
ldr s8, #8; b 8; data.f32 1.000000
fmul s20, s0, s8
fmul s21, s0, s8
ldr s8, #8; b 8; data.f32 2.000000
fmul s19, s0, s8
fmul s20, s0, s8
ldr s8, #8; b 8; data.f32 3.000000
fmul s18, s0, s8
fmul s19, s0, s8
ldr s8, #8; b 8; data.f32 4.000000
fmul s17, s0, s8
fmul s18, s0, s8
ldr s8, #8; b 8; data.f32 5.000000
fmul s16, s0, s8
fmul s17, s0, s8
ldr s8, #8; b 8; data.f32 6.000000
fmul s15, s0, s8
fmul s16, s0, s8
ldr s8, #8; b 8; data.f32 7.000000
fmul s14, s0, s8
fmul s15, s0, s8
ldr s8, #8; b 8; data.f32 8.000000
fmul s13, s0, s8
fmul s14, s0, s8
ldr s8, #8; b 8; data.f32 9.000000
fmul s12, s0, s8
fmul s13, s0, s8
ldr s8, #8; b 8; data.f32 10.000000
fmul s11, s0, s8
fmul s12, s0, s8
ldr s8, #8; b 8; data.f32 11.000000
fmul s10, s0, s8
fmul s11, s0, s8
ldr s8, #8; b 8; data.f32 12.000000
fmul s9, s0, s8
fmul s10, s0, s8
ldr s8, #8; b 8; data.f32 13.000000
fmul s9, s0, s8
ldr s8, #8; b 8; data.f32 14.000000
fmul s8, s0, s8
ldr s21, #8; b 8; data.f32 14.000000
fmul s21, s0, s21
ldr s22, #8; b 8; data.f32 15.000000
fmul s22, s0, s22
ldr s23, #8; b 8; data.f32 16.000000
Expand All @@ -1562,8 +1562,7 @@ L1 (SSA Block: blk0):
fadd s24, s24, s25
fadd s23, s23, s24
fadd s22, s22, s23
fadd s21, s21, s22
fadd s8, s8, s21
fadd s8, s8, s22
fadd s8, s9, s8
fadd s8, s10, s8
fadd s8, s11, s8
Expand All @@ -1575,7 +1574,8 @@ L1 (SSA Block: blk0):
fadd s8, s17, s8
fadd s8, s18, s8
fadd s8, s19, s8
fadd s0, s20, s8
fadd s8, s20, s8
fadd s0, s21, s8
add sp, sp, #0x10
ldr q27, [sp], #0x10
ldr q26, [sp], #0x10
Expand Down
127 changes: 94 additions & 33 deletions internal/engine/wazevo/backend/regalloc/assign.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package regalloc

import (
"fmt"
"sort"

"github.com/tetratelabs/wazero/internal/engine/wazevo/wazevoapi"
)
Expand All @@ -10,35 +11,97 @@ import (
// This is called after coloring is done.
func (a *Allocator) assignRegisters(f Function) {
for blk := f.ReversePostOrderBlockIteratorBegin(); blk != nil; blk = f.ReversePostOrderBlockIteratorNext() {
a.assignRegistersPerBlock(f, blk, a.vRegIDToNode)
a.assignRegistersPerBlock(f, blk)
}
}

// assignRegistersPerBlock assigns real registers to virtual registers on each instruction in a block.
func (a *Allocator) assignRegistersPerBlock(f Function, blk Block, vRegIDToNode []*node) {
func (a *Allocator) assignRegistersPerBlock(f Function, blk Block) {
if wazevoapi.RegAllocLoggingEnabled {
fmt.Println("---------------------- assigning registers for block", blk.ID(), "----------------------")
}

blkID := blk.ID()
info := a.blockInfoAt(blk.ID())
a.aliveSet = resetMap(a.aliveSet)
for v := range info.liveIns {
n := a.getOrAllocateNode(v)
a.aliveSet[n] = struct{}{}
}

if !blk.Entry() {
for _, arg := range blk.BlockParams() {
n := a.getOrAllocateNode(arg)
a.aliveSet[n] = struct{}{}
}
}

var pc programCounter
for instr := blk.InstrIteratorBegin(); instr != nil; instr = blk.InstrIteratorNext() {
tree := a.blockInfos[blkID].intervalMng
a.assignRegistersPerInstr(f, pc, instr, vRegIDToNode, tree)
if wazevoapi.RegAllocLoggingEnabled {
fmt.Printf("--- handling %v ---\n", instr)
for alive := range a.aliveSet {
fmt.Println("\t", alive)
}
}

a.assignRegistersPerInstr(f, info, pc, instr)
pc += pcStride
}
}

func (a *Allocator) assignRegistersPerInstr(f Function, pc programCounter, instr Instr, vRegIDToNode []*node, intervalMng *intervalManager) {
// collectOrderedActiveNodes refills a.nodes1 (reusing its backing storage) with a
// filtered view of a.aliveSet, sorted by ascending VReg ID.
//
// The filter depends on the `real` flag:
//   - real == true:  keep only the nodes whose assignedRealReg() is not RealRegInvalid.
//   - real == false: keep only the nodes that are neither spilled nor backed by a real VReg.
//
// a.aliveSet is a map, so its iteration order is unspecified; the final sort gives
// callers an ordering that does not depend on it.
func (a *Allocator) collectOrderedActiveNodes(real bool) {
	active := a.nodes1[:0]
	for candidate := range a.aliveSet {
		var skip bool
		if real {
			skip = candidate.assignedRealReg() == RealRegInvalid
		} else {
			skip = candidate.spill() || candidate.v.IsRealReg()
		}
		if !skip {
			active = append(active, candidate)
		}
	}
	a.nodes1 = active
	sort.Slice(active, func(i, j int) bool {
		return active[i].v.ID() < active[j].v.ID()
	})
}

// updateAliveNodesByUse removes from a.aliveSet the nodes whose liveness ends at instr.
//
// For each register used by instr:
//   - a real register is always removed from the alive set;
//   - a virtual register is removed only when info.lastUses reports pc as its last use
//     AND it is not present in info.liveOuts.
func (a *Allocator) updateAliveNodesByUse(info *blockInfo, pc programCounter, instr Instr) {
	for _, used := range instr.Uses() {
		usedNode := a.vRegIDToNode[used.ID()]
		vreg := usedNode.v
		if !vreg.IsRealReg() {
			if info.lastUses.Lookup(vreg) != pc {
				// Not the last use recorded for this block: stays alive.
				continue
			}
			if _, liveOut := info.liveOuts[vreg]; liveOut {
				// Still needed after this block: stays alive.
				continue
			}
		}
		delete(a.aliveSet, usedNode)
	}
}

// updateAliveNodesByDef adds to a.aliveSet the nodes of the registers defined by instr.
//
// A definition is skipped (not marked alive) only when all of the following hold:
// the register is virtual, info.lastUses.Lookup returns a negative value for it
// (presumably meaning "no use recorded in this block" — confirm against lastUses),
// and it is absent from info.liveOuts.
func (a *Allocator) updateAliveNodesByDef(info *blockInfo, instr Instr) {
	for _, defined := range instr.Defs() {
		defNode := a.vRegIDToNode[defined.ID()]
		vreg := defNode.v
		if vreg.IsRealReg() || info.lastUses.Lookup(vreg) >= 0 {
			a.aliveSet[defNode] = struct{}{}
			continue
		}
		// Virtual register with a negative lastUses lookup: alive only if it is a live-out.
		if _, liveOut := info.liveOuts[vreg]; liveOut {
			a.aliveSet[defNode] = struct{}{}
		}
	}
}

func (a *Allocator) assignRegistersPerInstr(f Function, info *blockInfo, pc programCounter, instr Instr) {
if indirect := instr.IsIndirectCall(); instr.IsCall() || indirect {
intervalMng.collectActiveNodes(
// To find the all the live registers "after" call, we need to add pcDefOffset for search.
pc+pcDefOffset,
&a.nodes1,
// Only take care of non-real VRegs (e.g. VReg.IsRealReg() == false) since
// the real VRegs are already placed in the right registers at this point.
false,
)
a.updateAliveNodesByUse(info, pc, instr)
a.updateAliveNodesByDef(info, instr)
a.collectOrderedActiveNodes(false)
for _, active := range a.nodes1 {
if r := active.r; a.regInfo.isCallerSaved(r) {
v := active.v.SetRealReg(r)
Expand All @@ -48,15 +111,7 @@ func (a *Allocator) assignRegistersPerInstr(f Function, pc programCounter, instr
}
if indirect {
// Direct function calls do not need assignment, while indirect one needs the assignment on the function pointer.
a.assignIndirectCall(f, instr, vRegIDToNode)
}

if wazevoapi.RegAllocValidationEnabled {
for _, def := range instr.Defs() {
if !def.IsRealReg() {
panic(fmt.Sprintf("BUG: call/indirect call instruction must define only real registers: %s", def))
}
}
a.assignIndirectCall(f, instr)
}
return
} else if instr.IsReturn() {
Expand All @@ -72,7 +127,7 @@ func (a *Allocator) assignRegistersPerInstr(f Function, pc programCounter, instr
if wazevoapi.RegAllocLoggingEnabled {
fmt.Printf("%s uses %s(%d)\n", instr, u.RegType(), u.ID())
}
n := vRegIDToNode[u.ID()]
n := a.vRegIDToNode[u.ID()]
if !n.spill() {
instr.AssignUse(i, u.SetRealReg(n.r))
} else {
Expand All @@ -91,7 +146,7 @@ func (a *Allocator) assignRegistersPerInstr(f Function, pc programCounter, instr
fmt.Printf("%s defines %s(%d)\n", instr, d.RegType(), d.ID())
}

n := vRegIDToNode[d.ID()]
n := a.vRegIDToNode[d.ID()]
if !n.spill() {
instr.AssignDef(d.SetRealReg(n.r))
} else {
Expand All @@ -102,19 +157,23 @@ func (a *Allocator) assignRegistersPerInstr(f Function, pc programCounter, instr
panic("BUG: multiple def instructions must be special cased")
}

a.handleSpills(f, pc, instr, usesSpills, defSpill, intervalMng)
a.handleSpills(f, info, pc, instr, usesSpills, defSpill)
a.vs = usesSpills[:0] // for reuse.
}

func (a *Allocator) handleSpills(
f Function, pc programCounter, instr Instr,
usesSpills []VReg, defSpill VReg, intervalMng *intervalManager,
f Function, info *blockInfo, pc programCounter, instr Instr,
usesSpills []VReg, defSpill VReg,
) {
_usesSpills, _defSpill := len(usesSpills) > 0, defSpill.Valid()
switch {
case !_usesSpills && !_defSpill: // Nothing to do.
a.updateAliveNodesByUse(info, pc, instr)
a.updateAliveNodesByDef(info, instr)
case !_usesSpills && _defSpill: // Only definition is spilled.
intervalMng.collectActiveNodes(pc+pcDefOffset, &a.nodes1, true)
a.updateAliveNodesByUse(info, pc, instr)
a.updateAliveNodesByDef(info, instr)
a.collectOrderedActiveNodes(true)
a.spillHandler.init(a.nodes1, instr)

r, evictedNode := a.spillHandler.getUnusedOrEvictReg(defSpill.RegType(), a.regInfo)
Expand All @@ -130,7 +189,8 @@ func (a *Allocator) handleSpills(
f.StoreRegisterAfter(defSpill, instr)

case _usesSpills:
intervalMng.collectActiveNodes(pc, &a.nodes1, true)
a.updateAliveNodesByUse(info, pc, instr)
a.collectOrderedActiveNodes(true)
a.spillHandler.init(a.nodes1, instr)

var evicted [3]*node
Expand Down Expand Up @@ -174,7 +234,8 @@ func (a *Allocator) handleSpills(

if !defSpill.IsRealReg() {
// This case, the destination register type is different from the source registers.
intervalMng.collectActiveNodes(pc+pcDefOffset, &a.nodes1, true)
a.updateAliveNodesByDef(info, instr)
a.collectOrderedActiveNodes(true)
a.spillHandler.init(a.nodes1, instr)
r, evictedNode := a.spillHandler.getUnusedOrEvictReg(defSpill.RegType(), a.regInfo)
if evictedNode != nil {
Expand All @@ -191,7 +252,7 @@ func (a *Allocator) handleSpills(
}
}

func (a *Allocator) assignIndirectCall(f Function, instr Instr, vRegIDToNode []*node) {
func (a *Allocator) assignIndirectCall(f Function, instr Instr) {
a.nodes1 = a.nodes1[:0]
uses := instr.Uses()
if wazevoapi.RegAllocValidationEnabled {
Expand All @@ -218,7 +279,7 @@ func (a *Allocator) assignIndirectCall(f Function, instr Instr, vRegIDToNode []*
panic(fmt.Sprintf("BUG: function pointer for indirect call must be an integer register: %s", v))
}

n := vRegIDToNode[v.ID()]
n := a.vRegIDToNode[v.ID()]
if n.spill() {
// If the function pointer is spilled, we need to reload it to a register.
// But at this point, all the caller-saved registers are saved, we can use a callee-saved register to reload.
Expand Down
2 changes: 2 additions & 0 deletions internal/engine/wazevo/backend/regalloc/assign_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//go:build tmp

package regalloc

import (
Expand Down
59 changes: 8 additions & 51 deletions internal/engine/wazevo/backend/regalloc/coloring.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,52 +7,6 @@ import (
"github.com/tetratelabs/wazero/internal/engine/wazevo/wazevoapi"
)

// buildNeighbors populates the neighbor list of every node allocated in a.nodePool.
//
// Before visiting the nodes it grows the a.dedup scratch slice when it is shorter than
// the number of allocated nodes, then delegates the per-node work to buildNeighborsFor.
func (a *Allocator) buildNeighbors() {
	total := a.nodePool.Allocated()
	if missing := total - len(a.dedup); missing > 0 {
		a.dedup = append(a.dedup, make([]bool, missing+1)...)
	}
	for idx := 0; idx < total; idx++ {
		a.buildNeighborsFor(a.nodePool.View(idx))
	}
}

// buildNeighborsFor appends to n.neighbors every other node of the same RegType that
// appears either in one of n's ranges or in a range listed as a neighbor of those ranges.
//
// a.dedup, indexed by node id, is used as a scratch "already added" marker so that each
// neighbor is appended at most once; the marks set here are cleared again before returning.
func (a *Allocator) buildNeighborsFor(n *node) {
	// link appends candidate to n.neighbors unless it has a different register type,
	// is n itself, or has already been recorded.
	link := func(candidate *node) {
		if candidate.v.RegType() != n.v.RegType() {
			return
		}
		if candidate == n || a.dedup[candidate.id] {
			return
		}
		n.neighbors = append(n.neighbors, candidate)
		a.dedup[candidate.id] = true
	}

	for _, rng := range n.ranges {
		// Nodes sharing the very same range.
		for _, candidate := range rng.nodes {
			link(candidate)
		}
		// Nodes living in the ranges adjacent to this one.
		for _, adjacent := range rng.neighbors {
			for _, candidate := range adjacent.nodes {
				link(candidate)
			}
		}
	}

	// Clear the scratch marks so the next call starts from a clean slate.
	for _, linked := range n.neighbors {
		a.dedup[linked.id] = false
	}
}

// coloring does the graph coloring for both RegType(s).
// Since the graphs are disjoint per RegType, we do it by RegType separately.
func (a *Allocator) coloring() {
Expand Down Expand Up @@ -88,11 +42,6 @@ func (a *Allocator) coloringFor(allocatable []RealReg) {

numAllocatable := len(allocatable)

// Initialize the degree for each node which is defined as the number of neighbors.
for _, n := range degreeSortedNodes {
n.degree = len(n.neighbors)
}

// Sort the nodes by the current degree.
sort.SliceStable(degreeSortedNodes, func(i, j int) bool {
return degreeSortedNodes[i].degree < degreeSortedNodes[j].degree
Expand Down Expand Up @@ -133,6 +82,7 @@ func (a *Allocator) coloringFor(allocatable []RealReg) {
for len(popTargetQueue) > 0 {
top := popTargetQueue[0]
popTargetQueue = popTargetQueue[1:]

for _, neighbor := range top.neighbors {
neighbor.degree--
if neighbor.degree < numAllocatable {
Expand Down Expand Up @@ -162,6 +112,13 @@ func (a *Allocator) coloringFor(allocatable []RealReg) {

// Gather already used colors.
for _, neighbor := range n.neighbors {
if wazevoapi.RegAllocLoggingEnabled {
if neighbor.r == RealRegInvalid {
fmt.Println("\tneighbor: ", neighbor, " (not colored yet)")
} else {
fmt.Println("\tneighbor: ", neighbor, " (colored)", a.regInfo.RealRegName(neighbor.r))
}
}
if neighborColor := neighbor.r; neighborColor != RealRegInvalid {
neighborColorsSet[neighborColor] = true
}
Expand Down
Loading
Loading