Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wazevo(regalloc): fixes phi spill logic #1846

Merged
merged 1 commit into from
Nov 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 27 additions & 43 deletions internal/engine/wazevo/backend/regalloc/regalloc.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ func NewAllocator(allocatableRegs *RegisterInfo) Allocator {
a := Allocator{
regInfo: allocatableRegs,
blockLivenessDataPool: wazevoapi.NewPool[blockLivenessData](resetBlockLivenessData),
phiDefInstListPool: wazevoapi.NewPool[phiDefInstList](resetPhiDefInstList),
}
a.state.reset()
for _, regs := range allocatableRegs.AllocatableRegisters {
Expand Down Expand Up @@ -56,10 +57,10 @@ type (
blockLivenessData [] /* blockID to */ *blockLivenessData
vs []VReg
maxBlockID int
phiDefInstListPool wazevoapi.Pool[phiDefInstList]

// Followings are re-used during various places e.g. coloring.
blks []Block
insts []Instr
reals []RealReg
currentOccupants regInUseSet

Expand Down Expand Up @@ -115,11 +116,23 @@ type (
lastUse programCounter
// isPhi is true if this is a phi value.
isPhi bool
// isArg is true if this is phi (isPhi=true) and the value is passed via a real register at the beginning of the blk.
regPhi RealReg
// phiDefInstList is a list of instructions that defines this phi value.
// This is used to determine the spill location, and only valid if isPhi=true.
*phiDefInstList
}

// phiDefInstList is a linked list of instructions that defines a phi value.
phiDefInstList struct {
instr Instr
next *phiDefInstList
}
)

func resetPhiDefInstList(l *phiDefInstList) {
l.instr = nil
l.next = nil
}

func (s *state) dump(info *RegisterInfo) { //nolint:unused
fmt.Println("\t\tstate:")
fmt.Println("\t\t\targRealRegs:", s.argRealRegs)
Expand Down Expand Up @@ -178,7 +191,7 @@ func (vs *vrState) reset() {
vs.spilled = false
vs.lca = nil
vs.isPhi = false
vs.regPhi = RealRegInvalid
vs.phiDefInstList = nil
}

func (s *state) getVRegState(v VReg) *vrState {
Expand Down Expand Up @@ -531,7 +544,6 @@ func (a *Allocator) allocBlock(f Function, blk Block) {
}

pc = 0
a.insts = a.insts[:0]
for instr := blk.InstrIteratorBegin(); instr != nil; instr = blk.InstrIteratorNext() {
if wazevoapi.RegAllocLoggingEnabled {
fmt.Println(instr)
Expand Down Expand Up @@ -641,11 +653,10 @@ func (a *Allocator) allocBlock(f Function, blk Block) {
fmt.Printf("\tdefining v%d with %s\n", def.ID(), a.regInfo.RealRegName(r))
}
if vState.isPhi {
if blk.Entry() {
// If this is the entry block, the phi value has a unique definition.
vState.defInstr = instr
}
a.insts = append(a.insts, instr)
n := a.phiDefInstListPool.Allocate()
n.instr = instr
n.next = vState.phiDefInstList
vState.phiDefInstList = n
} else {
vState.defInstr = instr
vState.defBlk = blk
Expand All @@ -658,16 +669,6 @@ func (a *Allocator) allocBlock(f Function, blk Block) {
pc++
}

if !blk.Entry() {
for _, phiDefInstr := range a.insts {
phiDefInstr.Defs(&a.vs)
phi := a.vs[0]
if s.getVRegState(phi).r == RealRegInvalid {
f.StoreRegisterAfter(phi, phiDefInstr)
}
}
}

s.regsInUse.range_(func(allocated RealReg, v VReg) {
currentBlkState.endRegs.add(allocated, v)
})
Expand Down Expand Up @@ -723,14 +724,6 @@ func (a *Allocator) fixMergeState(f Function, blk Block) {
fmt.Println("fixMergeState", blk.ID(), ":", desiredOccupants.format(a.regInfo))
}

// Record that register-allocated phis and not.
for _, phi := range blk.BlockParams(&a.vs) {
vs := s.getVRegState(phi)
if r, ok := aliveOnRegVRegs[phi]; ok {
vs.regPhi = r
}
}

currentOccupants := &a.currentOccupants
for i := 0; i < preds; i++ {
currentOccupants.reset()
Expand Down Expand Up @@ -885,21 +878,11 @@ func (a *Allocator) scheduleSpills(f Function) {

func (a *Allocator) scheduleSpill(f Function, vs *vrState) {
v := vs.v
// If the value is the phi value, we need to insert a spill before the first instruction of the defining block whose
// arguments contain the value.
if phiDefiningBlk := vs.defBlk; vs.isPhi &&
// Except for the entry block since the phi value is actually defined via the instruction.
!phiDefiningBlk.Entry() {
if r := vs.regPhi; r == RealRegInvalid {
// This case, the phi is already passed via stack performed in the code inserted in fixMergeState function.
if wazevoapi.RegAllocLoggingEnabled {
fmt.Printf("v%d is already passed via stack at blk%v\n", v.ID(), phiDefiningBlk.ID())
fmt.Println(vs.defInstr)
a.blockStates[phiDefiningBlk.ID()].dump(a.regInfo)
}
} else {
// Otherwise, we need to insert a spill before the first instruction of the block.
f.StoreRegisterAfter(v.SetRealReg(r), phiDefiningBlk.FirstInstr())
// If the value is the phi value, we need to insert a spill after each phi definition.
if vs.isPhi {
for defInstr := vs.phiDefInstList; defInstr != nil; defInstr = defInstr.next {
def := defInstr.instr.Defs(&a.vs)[0]
f.StoreRegisterAfter(def, defInstr.instr)
}
return
}
Expand Down Expand Up @@ -955,6 +938,7 @@ func (a *Allocator) Reset() {
s.reset()
}
a.blockLivenessDataPool.Reset()
a.phiDefInstListPool.Reset()

a.vs = a.vs[:0]
a.maxBlockID = -1
Expand Down
18 changes: 17 additions & 1 deletion internal/integration_test/fuzzcases/fuzzcases_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"embed"
"fmt"
"math"
"runtime"
"testing"

Expand Down Expand Up @@ -719,7 +720,7 @@ func Test1825(t *testing.T) {
})
}

// Test1825 tests that lowerFcopysignImpl allocates correctly the temporary registers.
// Test1826 tests that lowerFcopysignImpl allocates correctly the temporary registers.
func Test1826(t *testing.T) {
if !platform.CompilerSupported() {
return
Expand All @@ -734,3 +735,18 @@ func Test1826(t *testing.T) {
require.Equal(t, uint64(0), m.Globals[0].ValHi)
})
}

func Test1846(t *testing.T) {
if !platform.CompilerSupported() {
return
}
run(t, func(t *testing.T, r wazero.Runtime) {
mod, err := r.Instantiate(ctx, getWasmBinary(t, "1846"))
require.NoError(t, err)
m := mod.(*wasm.ModuleInstance)
_, err = m.ExportedFunction("").Call(ctx)
require.NoError(t, err)
require.Equal(t, math.Float64bits(2), m.Globals[0].Val)
require.Equal(t, uint64(0), m.Globals[0].ValHi)
})
}
Binary file not shown.
105 changes: 105 additions & 0 deletions internal/integration_test/fuzzcases/testdata/1846.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
(module
(type (;0;) (func (result f64 f64 f32 f64 f64 f32 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64)))
(type (;1;) (func (result f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64)))
(type (;2;) (func))
(func (;0;) (type 1) (result f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64)
f64.const 0.0
f64.const 0.0
f64.const 0.0
f64.const 0.0
f64.const 0.0
f64.const 0.0
f64.const 0.0
f64.const 0.0
f64.const 0.0
f64.const 0.0
f64.const 0.0
f64.const 0.0

;; unreachable
)
(func (;1;) (type 2)
(local f64 f64 i32)
i32.const 0
if (type 0) (result f64 f64 f32 f64 f64 f32 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64) ;; label = @1
f64.const 1.0
call 0
drop
drop
drop
drop
drop
drop
drop
drop
drop
drop
drop
drop
f64.const 0x0p+0 (;=0;)
f32.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f32.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
else
f64.const 2.0
f64.const 0x0p+0 (;=0;)
f32.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f32.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f64.const -0x1.a2d89e0481a8dp+83 (;=-15823560583422023000000000;)
f64.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
f64.const 0x0p+0 (;=0;)
end
drop
drop
drop
drop
drop
drop
drop
drop
drop
drop
drop
drop
drop
drop
drop
drop
drop
drop
drop
i64.reinterpret_f64
global.set 0
)
(global (;0;) (mut i64) i64.const 0)
(global (;1;) (mut i32) i32.const 0)
(export "" (func 1))
)