Skip to content

Commit

Permalink
Merge pull request #3757 from onflow/bastian/compile-switch
Browse files Browse the repository at this point in the history
[Compiler] Compile switch
  • Loading branch information
turbolent authored Feb 5, 2025
2 parents 21dfd63 + 54a9df5 commit 17bfacc
Show file tree
Hide file tree
Showing 3 changed files with 222 additions and 5 deletions.
40 changes: 37 additions & 3 deletions bbq/compiler/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -639,9 +639,43 @@ func (c *Compiler[_]) VisitEmitStatement(_ *ast.EmitStatement) (_ struct{}) {
panic(errors.NewUnreachableError())
}

func (c *Compiler[_]) VisitSwitchStatement(_ *ast.SwitchStatement) (_ struct{}) {
// TODO
panic(errors.NewUnreachableError())
func (c *Compiler[_]) VisitSwitchStatement(statement *ast.SwitchStatement) (_ struct{}) {
c.compileExpression(statement.Expression)
localIndex := c.currentFunction.generateLocalIndex()
c.codeGen.Emit(opcode.InstructionSetLocal{LocalIndex: localIndex})

endJumps := make([]int, 0, len(statement.Cases))
previousJump := -1

for _, switchCase := range statement.Cases {
if previousJump >= 0 {
c.patchJump(previousJump)
}

isDefault := switchCase.Expression == nil
if !isDefault {
c.codeGen.Emit(opcode.InstructionGetLocal{LocalIndex: localIndex})
c.compileExpression(switchCase.Expression)
c.codeGen.Emit(opcode.InstructionEqual{})
previousJump = c.emitUndefinedJumpIfFalse()

}

for _, caseStatement := range switchCase.Statements {
c.compileStatement(caseStatement)
}

if !isDefault {
endJump := c.emitUndefinedJump()
endJumps = append(endJumps, endJump)
}
}

for _, endJump := range endJumps {
c.patchJump(endJump)
}

return
}

func (c *Compiler[_]) VisitVariableDeclaration(declaration *ast.VariableDeclaration) (_ struct{}) {
Expand Down
103 changes: 103 additions & 0 deletions bbq/compiler/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -535,3 +535,106 @@ func TestCompileIfLet(t *testing.T) {
program.Constants,
)
}

func TestCompileSwitch(t *testing.T) {

t.Parallel()

checker, err := ParseAndCheck(t, `
fun test(x: Int): Int {
var a = 0
switch x {
case 1:
a = 1
case 2:
a = 2
default:
a = 3
}
return a
}
`)
require.NoError(t, err)

compiler := NewInstructionCompiler(checker)
program := compiler.Compile()

require.Len(t, program.Functions, 1)

assert.Equal(t,
[]opcode.Instruction{
// var a = 0
opcode.InstructionGetConstant{ConstantIndex: 0x0},
opcode.InstructionTransfer{TypeIndex: 0x0},
opcode.InstructionSetLocal{LocalIndex: 0x1},

// switch x
opcode.InstructionGetLocal{LocalIndex: 0x0},
opcode.InstructionSetLocal{LocalIndex: 0x2},

// case 1:
opcode.InstructionGetLocal{LocalIndex: 0x2},
opcode.InstructionGetConstant{ConstantIndex: 0x1},
opcode.InstructionEqual{},
opcode.InstructionJumpIfFalse{Target: 13},

// a = 1
opcode.InstructionGetConstant{ConstantIndex: 0x1},
opcode.InstructionTransfer{TypeIndex: 0x0},
opcode.InstructionSetLocal{LocalIndex: 0x1},

// jump to end
opcode.InstructionJump{Target: 24},

// case 2:
opcode.InstructionGetLocal{LocalIndex: 0x2},
opcode.InstructionGetConstant{ConstantIndex: 0x2},
opcode.InstructionEqual{},
opcode.InstructionJumpIfFalse{Target: 21},

// a = 2
opcode.InstructionGetConstant{ConstantIndex: 0x2},
opcode.InstructionTransfer{TypeIndex: 0x0},
opcode.InstructionSetLocal{LocalIndex: 0x1},

// jump to end
opcode.InstructionJump{Target: 24},

// default:
// a = 3
opcode.InstructionGetConstant{ConstantIndex: 0x3},
opcode.InstructionTransfer{TypeIndex: 0x0},
opcode.InstructionSetLocal{LocalIndex: 0x1},

// return a
opcode.InstructionGetLocal{LocalIndex: 0x1},
opcode.InstructionTransfer{TypeIndex: 0x0},
opcode.InstructionSetLocal{LocalIndex: 0x3},
opcode.InstructionGetLocal{LocalIndex: 0x3},
opcode.InstructionReturnValue{},
},
compiler.ExportFunctions()[0].Code,
)

assert.Equal(t,
[]*bbq.Constant{
{
Data: []byte{0x0},
Kind: constantkind.Int,
},
{
Data: []byte{0x1},
Kind: constantkind.Int,
},
{
Data: []byte{0x2},
Kind: constantkind.Int,
},
{
Data: []byte{0x3},
Kind: constantkind.Int,
},
},
program.Constants,
)
}
84 changes: 82 additions & 2 deletions bbq/vm/test/vm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3194,7 +3194,7 @@ func TestIfLet(t *testing.T) {
t.Parallel()

result, err := compileAndInvoke(t, `
fun main(x: Int?): Int {
fun main(x: Int?): Int {
if let y = x {
return y
} else {
Expand All @@ -3216,7 +3216,7 @@ func TestIfLet(t *testing.T) {
t.Parallel()

result, err := compileAndInvoke(t, `
fun main(x: Int?): Int {
fun main(x: Int?): Int {
if let y = x {
return y
} else {
Expand All @@ -3232,3 +3232,83 @@ func TestIfLet(t *testing.T) {
assert.Equal(t, vm.NewIntValue(2), result)
})
}

func TestCompileSwitch(t *testing.T) {

t.Parallel()

t.Run("1", func(t *testing.T) {
t.Parallel()

result, err := compileAndInvoke(t,
`
fun test(x: Int): Int {
var a = 0
switch x {
case 1:
a = a + 1
case 2:
a = a + 2
default:
a = a + 3
}
return a
}
`,
"test",
vm.NewIntValue(1),
)
require.NoError(t, err)
assert.Equal(t, vm.NewIntValue(1), result)
})

t.Run("2", func(t *testing.T) {
t.Parallel()

result, err := compileAndInvoke(t,
`
fun test(x: Int): Int {
var a = 0
switch x {
case 1:
a = a + 1
case 2:
a = a + 2
default:
a = a + 3
}
return a
}
`,
"test",
vm.NewIntValue(2),
)
require.NoError(t, err)
assert.Equal(t, vm.NewIntValue(2), result)
})

t.Run("4", func(t *testing.T) {
t.Parallel()

result, err := compileAndInvoke(t,
`
fun test(x: Int): Int {
var a = 0
switch x {
case 1:
a = a + 1
case 2:
a = a + 2
default:
a = a + 3
}
return a
}
`,
"test",
vm.NewIntValue(4),
)
require.NoError(t, err)
assert.Equal(t, vm.NewIntValue(3), result)
})
}

0 comments on commit 17bfacc

Please sign in to comment.