diff --git a/internal/engine/wazevo/backend/regalloc/regalloc.go b/internal/engine/wazevo/backend/regalloc/regalloc.go index ce9c6d7e38..bbaa5c6f9b 100644 --- a/internal/engine/wazevo/backend/regalloc/regalloc.go +++ b/internal/engine/wazevo/backend/regalloc/regalloc.go @@ -56,8 +56,8 @@ type ( allocatableSet RegSet allocatedCalleeSavedRegs []VReg vs []VReg - vs2 []VRegID - is []Instr + ss []*vrState + copies []_copy phiDefInstListPool wazevoapi.Pool[phiDefInstList] // Followings are re-used during various places. @@ -69,6 +69,12 @@ type ( blockStates wazevoapi.IDedPool[blockState] } + // _copy represents a source and destination pair of a copy instruction. + _copy struct { + src *vrState + dstID VRegID + } + // programCounter represents an opaque index into the program which is used to represents a LiveInterval of a VReg. programCounter int32 @@ -85,7 +91,7 @@ type ( blockState struct { // liveIns is a list of VReg that are live at the beginning of the block. - liveIns []VRegID + liveIns []*vrState // seen is true if the block is visited during the liveness analysis. seen bool // visited is true if the block is visited during the allocation phase. @@ -193,13 +199,6 @@ func (s *state) reset() { s.currentBlockID = -1 } -func (s *state) setVRegState(v VReg, r RealReg) { - id := int(v.ID()) - st := s.vrStates.GetOrAllocate(id) - st.r = r - st.v = v -} - func resetVrState(vs *vrState) { vs.v = VRegInvalid vs.r = RealRegInvalid @@ -214,24 +213,32 @@ func resetVrState(vs *vrState) { vs.desiredLoc = desiredLocUnspecified } +func (s *state) getOrAllocateVRegState(v VReg) *vrState { + st := s.vrStates.GetOrAllocate(int(v.ID())) + if st.v == VRegInvalid { + st.v = v + } + return st +} + func (s *state) getVRegState(v VRegID) *vrState { - return s.vrStates.GetOrAllocate(int(v)) + return s.vrStates.Get(int(v)) } -func (s *state) useRealReg(r RealReg, v VReg) { +func (s *state) useRealReg(r RealReg, vr *vrState) { if s.regsInUse.has(r) { panic("BUG: useRealReg: the given real register is already used") } - s.regsInUse.add(r, v) - s.setVRegState(v, r) + s.regsInUse.add(r, vr) + vr.r = r s.allocatedRegSet = s.allocatedRegSet.add(r) } func (s *state) releaseRealReg(r RealReg) { current := s.regsInUse.get(r) - if current.Valid() { + if current != nil { s.regsInUse.remove(r) - s.setVRegState(current, RealRegInvalid) + current.r = RealRegInvalid } } @@ -276,7 +283,7 @@ func (s *state) findOrSpillAllocatable(a *Allocator, allocatable []RealReg, forb } using := s.regsInUse.get(candidateReal) - if using == VRegInvalid { + if using == nil { // This is not used at this point. return candidateReal } @@ -285,17 +292,17 @@ func (s *state) findOrSpillAllocatable(a *Allocator, allocatable []RealReg, forb // For example, if the register is used as an argument register, and it might be // spilled and not reloaded when it ends up being used as a temporary to pass // stack based argument. - if using.IsRealReg() { + if using.v.IsRealReg() { continue } isPreferred := candidateReal == preferred // last == -1 means the value won't be used anymore. - if last := s.getVRegState(using.ID()).lastUse; r == RealRegInvalid || isPreferred || last == -1 || (lastUseAt != -1 && last > lastUseAt) { + if last := using.lastUse; r == RealRegInvalid || isPreferred || last == -1 || (lastUseAt != -1 && last > lastUseAt) { lastUseAt = last r = candidateReal - spillVReg = using + spillVReg = using.v if isPreferred { break } @@ -323,16 +330,14 @@ func (s *state) findAllocatable(allocatable []RealReg, forbiddenMask RegSet) Rea } func (s *state) resetAt(bs *blockState) { - s.regsInUse.range_(func(_ RealReg, vr VReg) { - s.setVRegState(vr, RealRegInvalid) + s.regsInUse.range_(func(_ RealReg, vs *vrState) { + vs.r = RealRegInvalid }) s.regsInUse.reset() - bs.endRegs.range_(func(r RealReg, v VReg) { - id := int(v.ID()) - st := s.vrStates.GetOrAllocate(id) - if st.lastUseUpdatedAtBlockID == s.currentBlockID && st.lastUse == programCounterLiveIn { - s.regsInUse.add(r, v) - s.setVRegState(v, r) + bs.endRegs.range_(func(r RealReg, vs *vrState) { + if vs.lastUseUpdatedAtBlockID == s.currentBlockID && vs.lastUse == programCounterLiveIn { + s.regsInUse.add(r, vs) + vs.r = r } }) } @@ -376,8 +381,7 @@ func (a *Allocator) getOrAllocateBlockState(blockID int32) *blockState { } // phiBlk returns the block that defines the given phi value, nil otherwise. -func (s *state) phiBlk(v VRegID) Block { - vs := s.getVRegState(v) +func (vs *vrState) phiBlk() Block { if vs.isPhi { return vs.defBlk } @@ -394,10 +398,14 @@ const ( func (a *Allocator) livenessAnalysis(f Function) { s := &a.state + for i := VRegID(0); i < vRegIDReservedForRealNum; i++ { + s.getOrAllocateVRegState(VReg(i).SetRealReg(RealReg(i))) + } + for blk := f.PostOrderBlockIteratorBegin(); blk != nil; blk = f.PostOrderBlockIteratorNext() { // We should gather phi value data. for _, p := range blk.BlockParams(&a.vs) { - vs := s.getVRegState(p.ID()) + vs := s.getOrAllocateVRegState(p) vs.isPhi = true vs.defBlk = blk } @@ -405,7 +413,7 @@ func (a *Allocator) livenessAnalysis(f Function) { blkID := blk.ID() info := a.getOrAllocateBlockState(blkID) - a.vs2 = a.vs2[:0] + a.ss = a.ss[:0] const ( flagDeleted = false flagLive = true @@ -423,12 +431,11 @@ func (a *Allocator) livenessAnalysis(f Function) { continue } - for _, v := range succInfo.liveIns { - if s.phiBlk(v) != succ { - st := s.getVRegState(v) + for _, st := range succInfo.liveIns { + if st.phiBlk() != succ { // We use .spilled field to store the flag. st.spilled = flagLive - a.vs2 = append(a.vs2, v) + a.ss = append(a.ss, st) } } } @@ -436,26 +443,26 @@ func (a *Allocator) livenessAnalysis(f Function) { for instr := blk.InstrRevIteratorBegin(); instr != nil; instr = blk.InstrRevIteratorNext() { var use, def VReg + var defIsPhi bool for _, def = range instr.Defs(&a.vs) { if !def.IsRealReg() { - id := def.ID() - st := s.getVRegState(id) + st := s.getOrAllocateVRegState(def) + defIsPhi = st.isPhi // We use .spilled field to store the flag. st.spilled = flagDeleted - a.vs2 = append(a.vs2, id) + a.ss = append(a.ss, st) } } for _, use = range instr.Uses(&a.vs) { if !use.IsRealReg() { - id := use.ID() - st := s.getVRegState(id) + st := s.getOrAllocateVRegState(use) // We use .spilled field to store the flag. st.spilled = flagLive - a.vs2 = append(a.vs2, id) + a.ss = append(a.ss, st) } } - if def.Valid() && s.phiBlk(def.ID()) != nil { + if defIsPhi { if use.Valid() && use.IsRealReg() { // If the destination is a phi value, and the source is a real register, this is the beginning of the function. a.state.argRealRegs = append(a.state.argRealRegs, use) @@ -463,11 +470,10 @@ func (a *Allocator) livenessAnalysis(f Function) { } } - for _, v := range a.vs2 { - st := s.getVRegState(v) + for _, st := range a.ss { // We use .spilled field to store the flag. if st.spilled == flagLive { //nolint:gosimple - info.liveIns = append(info.liveIns, v) + info.liveIns = append(info.liveIns, st) st.spilled = false } } @@ -487,27 +493,25 @@ func (a *Allocator) loopTreeDFS(entry Block) { a.blks = a.blks[:0] a.blks = append(a.blks, entry) - s := &a.state for len(a.blks) > 0 { tail := len(a.blks) - 1 loop := a.blks[tail] a.blks = a.blks[:tail] - a.vs2 = a.vs2[:0] + a.ss = a.ss[:0] const ( flagDone = false flagPending = true ) info := a.getOrAllocateBlockState(loop.ID()) - for _, v := range info.liveIns { - if s.phiBlk(v) != loop { - a.vs2 = append(a.vs2, v) - st := s.getVRegState(v) + for _, st := range info.liveIns { + if st.phiBlk() != loop { + a.ss = append(a.ss, st) // We use .spilled field to store the flag. st.spilled = flagPending } } - var siblingAddedView []VRegID + var siblingAddedView []*vrState cn := loop.LoopNestingForestChildren() for i := 0; i < cn; i++ { child := loop.LoopNestingForestChild(i) @@ -516,13 +520,12 @@ func (a *Allocator) loopTreeDFS(entry Block) { if i == 0 { begin := len(childInfo.liveIns) - for _, v := range a.vs2 { - st := s.getVRegState(v) + for _, st := range a.ss { // We use .spilled field to store the flag. if st.spilled == flagPending { //nolint:gosimple st.spilled = flagDone // TODO: deduplicate, though I don't think it has much impact. - childInfo.liveIns = append(childInfo.liveIns, v) + childInfo.liveIns = append(childInfo.liveIns, st) } } siblingAddedView = childInfo.liveIns[begin:] @@ -538,8 +541,7 @@ func (a *Allocator) loopTreeDFS(entry Block) { if cn == 0 { // If there's no forest child, we haven't cleared the .spilled field at this point. - for _, v := range a.vs2 { - st := s.getVRegState(v) + for _, st := range a.ss { st.spilled = false } } @@ -577,8 +579,7 @@ func (a *Allocator) alloc(f Function) { func (a *Allocator) updateLiveInVRState(liveness *blockState) { currentBlockID := a.state.currentBlockID - for _, v := range liveness.liveIns { - vs := a.state.getVRegState(v) + for _, vs := range liveness.liveIns { vs.lastUse = programCounterLiveIn vs.lastUseUpdatedAtBlockID = currentBlockID } @@ -619,7 +620,7 @@ func (a *Allocator) finalizeStartReg(blk Block) { panic(fmt.Sprintf("BUG: at lease one predecessor should be visited for blk%d", blk.ID())) } for _, u := range s.argRealRegs { - s.useRealReg(u.RealReg(), u) + s.useRealReg(u.RealReg(), s.getVRegState(u.ID())) } currentBlkState.startFromPredIndex = 0 } else { @@ -630,7 +631,7 @@ func (a *Allocator) finalizeStartReg(blk Block) { s.resetAt(predState) } - s.regsInUse.range_(func(allocated RealReg, v VReg) { + s.regsInUse.range_(func(allocated RealReg, v *vrState) { currentBlkState.startRegs.add(allocated, v) }) if wazevoapi.RegAllocLoggingEnabled { @@ -649,38 +650,33 @@ func (a *Allocator) allocBlock(f Function, blk Block) { } // Clears the previous state. - s.regsInUse.range_(func(allocatedRealReg RealReg, vr VReg) { - s.setVRegState(vr, RealRegInvalid) - }) + s.regsInUse.range_(func(allocatedRealReg RealReg, vr *vrState) { vr.r = RealRegInvalid }) s.regsInUse.reset() // Then set the start state. - currentBlkState.startRegs.range_(func(allocatedRealReg RealReg, vr VReg) { - s.useRealReg(allocatedRealReg, vr) - }) + currentBlkState.startRegs.range_(func(allocatedRealReg RealReg, vr *vrState) { s.useRealReg(allocatedRealReg, vr) }) - desiredUpdated := a.vs2[:0] + desiredUpdated := a.ss[:0] // Update the last use of each VReg. - a.is = a.is[:0] // Stores the copy instructions. + a.copies = a.copies[:0] // Stores the copy instructions. var pc programCounter for instr := blk.InstrIteratorBegin(); instr != nil; instr = blk.InstrIteratorNext() { - var use, def VReg - for _, use = range instr.Uses(&a.vs) { + var useState *vrState + for _, use := range instr.Uses(&a.vs) { + useState = s.getVRegState(use.ID()) if !use.IsRealReg() { - s.getVRegState(use.ID()).lastUse = pc + useState.lastUse = pc } } if instr.IsCopy() { - a.is = append(a.is, instr) - def = instr.Defs(&a.vs)[0] + def := instr.Defs(&a.vs)[0] + a.copies = append(a.copies, _copy{src: useState, dstID: def.ID()}) r := def.RealReg() if r != RealRegInvalid { - useID := use.ID() - vs := s.getVRegState(useID) - if !vs.isPhi { // TODO: no idea why do we need this. - vs.desiredLoc = newDesiredLocReg(r) - desiredUpdated = append(desiredUpdated, useID) + if !useState.isPhi { // TODO: no idea why do we need this. + useState.desiredLoc = newDesiredLocReg(r) + desiredUpdated = append(desiredUpdated, useState) } } } @@ -698,9 +694,8 @@ func (a *Allocator) allocBlock(f Function, blk Block) { succID := succ.ID() succState := a.getOrAllocateBlockState(succID) - for _, v := range succState.liveIns { - if s.phiBlk(v) != succ { - st := s.getVRegState(v) + for _, st := range succState.liveIns { + if st.phiBlk() != succ { st.lastUse = programCounterLiveOut } } @@ -709,36 +704,28 @@ func (a *Allocator) allocBlock(f Function, blk Block) { if wazevoapi.RegAllocLoggingEnabled { fmt.Printf("blk%d -> blk%d: start_regs: %s\n", bID, succID, succState.startRegs.format(a.regInfo)) } - succState.startRegs.range_(func(allocatedRealReg RealReg, vr VReg) { - vs := s.getVRegState(vr.ID()) + succState.startRegs.range_(func(allocatedRealReg RealReg, vs *vrState) { vs.desiredLoc = newDesiredLocReg(allocatedRealReg) - desiredUpdated = append(desiredUpdated, vr.ID()) + desiredUpdated = append(desiredUpdated, vs) }) for _, p := range succ.BlockParams(&a.vs) { vs := s.getVRegState(p.ID()) if vs.desiredLoc.realReg() == RealRegInvalid { vs.desiredLoc = desiredLocStack - desiredUpdated = append(desiredUpdated, p.ID()) + desiredUpdated = append(desiredUpdated, vs) } } } } // Propagate the desired register values from the end of the block to the beginning. - for _, instr := range a.is { - def := instr.Defs(&a.vs)[0] - defState := s.getVRegState(def.ID()) + for _, instr := range a.copies { + defState := s.getVRegState(instr.dstID) desired := defState.desiredLoc.realReg() - if desired == RealRegInvalid { - continue - } - - use := instr.Uses(&a.vs)[0] - useID := use.ID() - useState := s.getVRegState(useID) - if s.phiBlk(useID) != succ && useState.desiredLoc == desiredLocUnspecified { + useState := instr.src + if useState.phiBlk() != succ && useState.desiredLoc == desiredLocUnspecified { useState.desiredLoc = newDesiredLocReg(desired) - desiredUpdated = append(desiredUpdated, useID) + desiredUpdated = append(desiredUpdated, useState) } } @@ -780,7 +767,7 @@ func (a *Allocator) allocBlock(f Function, blk Block) { vs.desiredLoc.realReg()) vs.recordReload(f, blk) f.ReloadRegisterBefore(use.SetRealReg(r), instr) - s.useRealReg(r, use) + s.useRealReg(r, vs) } if wazevoapi.RegAllocLoggingEnabled { fmt.Printf("\ttrying to use v%v on %s\n", use.ID(), a.regInfo.RealRegName(r)) @@ -827,21 +814,21 @@ func (a *Allocator) allocBlock(f Function, blk Block) { if s.regsInUse.has(r) { s.releaseRealReg(r) } - s.useRealReg(r, def) + s.useRealReg(r, s.getVRegState(def.ID())) } case 0: case 1: def := defs[0] + vState := s.getVRegState(def.ID()) if def.IsRealReg() { r := def.RealReg() if a.allocatableSet.has(r) { if s.regsInUse.has(r) { s.releaseRealReg(r) } - s.useRealReg(r, def) + s.useRealReg(r, vState) } } else { - vState := s.getVRegState(def.ID()) r := vState.r if desired := vState.desiredLoc.realReg(); desired != RealRegInvalid { @@ -862,7 +849,7 @@ func (a *Allocator) allocBlock(f Function, blk Block) { } r = desired s.releaseRealReg(r) - s.useRealReg(r, def) + s.useRealReg(r, vState) } } } @@ -880,7 +867,7 @@ func (a *Allocator) allocBlock(f Function, blk Block) { typ := def.RegType() r = s.findOrSpillAllocatable(a, a.regInfo.AllocatableRegisters[typ], RegSet(0), RealRegInvalid) } - s.useRealReg(r, def) + s.useRealReg(r, vState) } dr := def.SetRealReg(r) instr.AssignDef(dr) @@ -913,9 +900,7 @@ func (a *Allocator) allocBlock(f Function, blk Block) { pc++ } - s.regsInUse.range_(func(allocated RealReg, v VReg) { - currentBlkState.endRegs.add(allocated, v) - }) + s.regsInUse.range_(func(allocated RealReg, v *vrState) { currentBlkState.endRegs.add(allocated, v) }) currentBlkState.visited = true if wazevoapi.RegAllocLoggingEnabled { @@ -923,11 +908,10 @@ func (a *Allocator) allocBlock(f Function, blk Block) { } // Reset the desired end location. - for _, v := range desiredUpdated { - vs := s.getVRegState(v) + for _, vs := range desiredUpdated { vs.desiredLoc = desiredLocUnspecified } - a.vs2 = desiredUpdated[:0] + a.ss = desiredUpdated[:0] for i := 0; i < blk.Succs(); i++ { succ := blk.Succ(i) @@ -946,8 +930,8 @@ func (a *Allocator) releaseCallerSavedRegs(addrReg RealReg) { if allocated == addrReg { // If this is the call indirect, we should not touch the addr register. continue } - if v := s.regsInUse.get(allocated); v.Valid() { - if v.IsRealReg() { + if vs := s.regsInUse.get(allocated); vs != nil { + if vs.v.IsRealReg() { continue // This is the argument register as it's already used by VReg backed by the corresponding RealReg. } if !a.regInfo.CallerSavedRegisters.has(allocated) { @@ -973,7 +957,7 @@ func (a *Allocator) fixMergeState(f Function, blk Block) { desiredOccupants := &blkSt.startRegs var desiredOccupantsSet RegSet for i, v := range desiredOccupants { - if v != VRegInvalid { + if v != nil { desiredOccupantsSet = desiredOccupantsSet.add(RealReg(i)) } } @@ -1010,16 +994,16 @@ func (a *Allocator) fixMergeState(f Function, blk Block) { for r := RealReg(0); r < 64; r++ { desiredVReg := desiredOccupants.get(r) - if !desiredVReg.Valid() { + if desiredVReg == nil { continue } currentVReg := s.regsInUse.get(r) - if desiredVReg.ID() == currentVReg.ID() { + if currentVReg != nil && desiredVReg.v.ID() == currentVReg.v.ID() { continue } - typ := desiredVReg.RegType() + typ := desiredVReg.v.RegType() var tmpRealReg VReg if typ == RegTypeInt { tmpRealReg = intTmp @@ -1040,10 +1024,15 @@ func (a *Allocator) fixMergeState(f Function, blk Block) { func (a *Allocator) reconcileEdge(f Function, r RealReg, pred Block, - currentVReg, desiredVReg VReg, + currentState, desiredState *vrState, freeReg VReg, typ RegType, ) { + desiredVReg := desiredState.v + currentVReg := VRegInvalid + if currentState != nil { + currentVReg = currentState.v + } // There are four cases to consider: // 1. currentVReg is valid, but desiredVReg is on the stack. // 2. Both currentVReg and desiredVReg are valid. @@ -1052,7 +1041,6 @@ func (a *Allocator) reconcileEdge(f Function, s := &a.state if currentVReg.Valid() { - desiredState := s.getVRegState(desiredVReg.ID()) er := desiredState.r if er == RealRegInvalid { // Case 1: currentVReg is valid, but desiredVReg is on the stack. @@ -1066,9 +1054,9 @@ func (a *Allocator) reconcileEdge(f Function, f.StoreRegisterBefore(currentVReg.SetRealReg(r), pred.LastInstrForInsertion()) s.releaseRealReg(r) - s.getVRegState(desiredVReg.ID()).recordReload(f, pred) + desiredState.recordReload(f, pred) f.ReloadRegisterBefore(desiredVReg.SetRealReg(r), pred.LastInstrForInsertion()) - s.useRealReg(r, desiredVReg) + s.useRealReg(r, desiredState) return } else { // Case 2: Both currentVReg and desiredVReg are valid. @@ -1087,8 +1075,8 @@ func (a *Allocator) reconcileEdge(f Function, s.allocatedRegSet = s.allocatedRegSet.add(freeReg.RealReg()) s.releaseRealReg(r) s.releaseRealReg(er) - s.useRealReg(r, desiredVReg) - s.useRealReg(er, currentVReg) + s.useRealReg(r, desiredState) + s.useRealReg(er, currentState) if wazevoapi.RegAllocLoggingEnabled { fmt.Printf("\t\tv%d previously on %s moved to %s\n", currentVReg.ID(), a.regInfo.RealRegName(r), a.regInfo.RealRegName(er)) } @@ -1099,7 +1087,7 @@ func (a *Allocator) reconcileEdge(f Function, desiredVReg.ID(), a.regInfo.RealRegName(r), ) } - if currentReg := s.getVRegState(desiredVReg.ID()).r; currentReg != RealRegInvalid { + if currentReg := desiredState.r; currentReg != RealRegInvalid { // Case 3: Desired is on a different register than `r` and currentReg is not valid. // We simply need to move the desired value to the register. f.InsertMoveBefore( @@ -1111,10 +1099,10 @@ func (a *Allocator) reconcileEdge(f Function, } else { // Case 4: Both currentVReg and desiredVReg are not valid. // We simply need to reload the desired value into the register. - s.getVRegState(desiredVReg.ID()).recordReload(f, pred) + desiredState.recordReload(f, pred) f.ReloadRegisterBefore(desiredVReg.SetRealReg(r), pred.LastInstrForInsertion()) } - s.useRealReg(r, desiredVReg) + s.useRealReg(r, desiredState) } } @@ -1157,7 +1145,7 @@ func (a *Allocator) scheduleSpill(f Function, vs *vrState) { for pos != definingBlk { st := a.getOrAllocateBlockState(pos.ID()) for rr := RealReg(0); rr < 64; rr++ { - if st.startRegs.get(rr) == v { + if vs := st.startRegs.get(rr); vs != nil && vs.v == v { r = rr // Already in the register, so we can place the spill at the beginning of the block. break diff --git a/internal/engine/wazevo/backend/regalloc/regalloc_test.go b/internal/engine/wazevo/backend/regalloc/regalloc_test.go index 526d27928c..818b59b953 100644 --- a/internal/engine/wazevo/backend/regalloc/regalloc_test.go +++ b/internal/engine/wazevo/backend/regalloc/regalloc_test.go @@ -13,10 +13,15 @@ func TestAllocator_livenessAnalysis(t *testing.T) { const realRegID, realRegID2 = 50, 100 realReg, realReg2 := FromRealReg(realRegID, RegTypeInt), FromRealReg(realRegID2, RegTypeInt) phiVReg := VReg(12345).SetRegType(RegTypeInt) + + type exp struct { + liveIns []VRegID + } + for _, tc := range []struct { name string setup func() Function - exp map[int]*blockState + exps map[int]*exp }{ { name: "single block", @@ -28,7 +33,7 @@ func TestAllocator_livenessAnalysis(t *testing.T) { ).entry(), ) }, - exp: map[int]*blockState{ + exps: map[int]*exp{ 0: {}, }, }, @@ -46,7 +51,7 @@ func TestAllocator_livenessAnalysis(t *testing.T) { blk.blockParam(param) return newMockFunction(blk) }, - exp: map[int]*blockState{ + exps: map[int]*exp{ 0: {}, }, }, @@ -71,7 +76,7 @@ func TestAllocator_livenessAnalysis(t *testing.T) { b1.addPred(b0) return newMockFunction(b0, b1, b2) }, - exp: map[int]*blockState{ + exps: map[int]*exp{ 0: {}, 1: { liveIns: []VRegID{3}, @@ -112,7 +117,7 @@ func TestAllocator_livenessAnalysis(t *testing.T) { b2.addPred(b0) return newMockFunction(b0, b1, b2, b3) }, - exp: map[int]*blockState{ + exps: map[int]*exp{ 0: {}, 1: {liveIns: []VRegID{1000, 1}}, 2: { @@ -154,7 +159,7 @@ func TestAllocator_livenessAnalysis(t *testing.T) { b1.addPred(b0) return newMockFunction(b0, b1, b2, b3, b4) }, - exp: map[int]*blockState{ + exps: map[int]*exp{ 0: {}, 1: { liveIns: []VRegID{2000, 3000}, @@ -212,7 +217,7 @@ func TestAllocator_livenessAnalysis(t *testing.T) { f.loopNestingForestRoots(b1) return f }, - exp: map[int]*blockState{ + exps: map[int]*exp{ 0: { liveIns: []VRegID{}, }, @@ -253,7 +258,7 @@ func TestAllocator_livenessAnalysis(t *testing.T) { f.loopNestingForestRoots(b2) return f }, - exp: map[int]*blockState{ + exps: map[int]*exp{ 0: {}, 1: { liveIns: []VRegID{9999}, @@ -303,7 +308,7 @@ func TestAllocator_livenessAnalysis(t *testing.T) { f.loopNestingForestRoots(b1) return f }, - exp: map[int]*blockState{ + exps: map[int]*exp{ 0: { liveIns: []VRegID{111}, }, @@ -335,17 +340,21 @@ func TestAllocator_livenessAnalysis(t *testing.T) { continue } t.Run(fmt.Sprintf("block_id=%d", blockID), func(t *testing.T) { - exp := tc.exp[blockID] + exp := tc.exps[blockID] if len(exp.liveIns) == 0 { require.Nil(t, actual.liveIns, "live ins") } else { - sort.Slice(actual.liveIns, func(i, j int) bool { - return actual.liveIns[i] < actual.liveIns[j] + var actuals []VRegID + for _, s := range actual.liveIns { + actuals = append(actuals, s.v.ID()) + } + sort.Slice(actuals, func(i, j int) bool { + return actuals[i] < actuals[j] }) sort.Slice(exp.liveIns, func(i, j int) bool { return exp.liveIns[i] < exp.liveIns[j] }) - require.Equal(t, exp.liveIns, actual.liveIns, "live ins") + require.Equal(t, exp.liveIns, actuals, "live ins") } }) } @@ -367,13 +376,13 @@ func TestAllocator_livenessAnalysis_copy(t *testing.T) { func Test_findOrSpillAllocatable_prefersSpill(t *testing.T) { t.Run("ok", func(t *testing.T) { s := &state{regsInUse: newRegInUseSet()} - s.regsInUse.add(RealReg(1), VReg(2222222)) + s.regsInUse.add(RealReg(1), &vrState{v: VReg(2222222)}) got := s.findOrSpillAllocatable(&Allocator{}, []RealReg{3}, 0, 3) require.Equal(t, RealReg(3), got) }) t.Run("preferred but in use", func(t *testing.T) { s := &state{vrStates: wazevoapi.NewIDedPool[vrState](resetVrState)} - s.regsInUse.add(RealReg(3), VReg(1).SetRealReg(3)) + s.regsInUse.add(RealReg(3), &vrState{v: VReg(1).SetRealReg(3)}) got := s.findOrSpillAllocatable(&Allocator{}, []RealReg{3, 4}, 0, 3) require.Equal(t, RealReg(4), got) }) diff --git a/internal/engine/wazevo/backend/regalloc/regset.go b/internal/engine/wazevo/backend/regalloc/regset.go index 04a8e8f4db..1624fc96f9 100644 --- a/internal/engine/wazevo/backend/regalloc/regset.go +++ b/internal/engine/wazevo/backend/regalloc/regset.go @@ -46,7 +46,7 @@ func (rs RegSet) Range(f func(allocatedRealReg RealReg)) { } } -type regInUseSet [64]VReg +type regInUseSet [64]*vrState func newRegInUseSet() regInUseSet { var ret regInUseSet @@ -56,42 +56,42 @@ func newRegInUseSet() regInUseSet { func (rs *regInUseSet) reset() { for i := range rs { - rs[i] = VRegInvalid + rs[i] = nil } } func (rs *regInUseSet) format(info *RegisterInfo) string { //nolint:unused var ret []string for i, vr := range rs { - if vr != VRegInvalid { - ret = append(ret, fmt.Sprintf("(%s->v%d)", info.RealRegName(RealReg(i)), vr.ID())) + if vr != nil { + ret = append(ret, fmt.Sprintf("(%s->v%d)", info.RealRegName(RealReg(i)), vr.v.ID())) } } return strings.Join(ret, ", ") } func (rs *regInUseSet) has(r RealReg) bool { - return r < 64 && rs[r] != VRegInvalid + return r < 64 && rs[r] != nil } -func (rs *regInUseSet) get(r RealReg) VReg { +func (rs *regInUseSet) get(r RealReg) *vrState { return rs[r] } func (rs *regInUseSet) remove(r RealReg) { - rs[r] = VRegInvalid + rs[r] = nil } -func (rs *regInUseSet) add(r RealReg, vr VReg) { +func (rs *regInUseSet) add(r RealReg, vr *vrState) { if r >= 64 { return } rs[r] = vr } -func (rs *regInUseSet) range_(f func(allocatedRealReg RealReg, vr VReg)) { +func (rs *regInUseSet) range_(f func(allocatedRealReg RealReg, vr *vrState)) { for i, vr := range rs { - if vr != VRegInvalid { + if vr != nil { f(RealReg(i), vr) } }