From f77e1a41f631c9328c5fba43625b4c1830caa4da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Wed, 10 Jan 2024 13:22:15 -0800 Subject: [PATCH 1/3] add support for dereferencing optional references --- runtime/interpreter/interpreter_expression.go | 25 ++++++++++++--- runtime/sema/check_unary_expression.go | 15 ++++++++- runtime/tests/checker/reference_test.go | 30 ++++++++++++++++++ runtime/tests/interpreter/reference_test.go | 31 +++++++++++++++++++ 4 files changed, 96 insertions(+), 5 deletions(-) diff --git a/runtime/interpreter/interpreter_expression.go b/runtime/interpreter/interpreter_expression.go index fc162a06c5..c8c931ec2b 100644 --- a/runtime/interpreter/interpreter_expression.go +++ b/runtime/interpreter/interpreter_expression.go @@ -697,16 +697,33 @@ func (interpreter *Interpreter) VisitUnaryExpression(expression *ast.UnaryExpres ) case ast.OperationMul: - referenceValue, ok := value.(ReferenceValue) - if !ok { - panic(errors.NewUnreachableError()) + + if _, ok := value.(NilValue); ok { + return Nil } + locationRange := LocationRange{ Location: interpreter.Location, HasPosition: expression, } + var isOptional bool + + if someValue, ok := value.(*SomeValue); ok { + isOptional = true + value = someValue.InnerValue(interpreter, locationRange) + } - return DereferenceValue(interpreter, locationRange, referenceValue) + referenceValue, ok := value.(ReferenceValue) + if !ok { + panic(errors.NewUnreachableError()) + } + + dereferencedValue := DereferenceValue(interpreter, locationRange, referenceValue) + if isOptional { + return NewSomeValueNonCopying(interpreter, dereferencedValue) + } else { + return dereferencedValue + } case ast.OperationMove: interpreter.invalidateResource(value) diff --git a/runtime/sema/check_unary_expression.go b/runtime/sema/check_unary_expression.go index 1cf061a46e..0503a1d8db 100644 --- a/runtime/sema/check_unary_expression.go +++ b/runtime/sema/check_unary_expression.go @@ -62,6 +62,13 @@ func (checker *Checker) VisitUnaryExpression(expression *ast.UnaryExpression) Ty return checkExpectedType(valueType, SignedNumberType) case ast.OperationMul: + + var isOptional bool + if optionalType, ok := valueType.(*OptionalType); ok { + isOptional = true + valueType = optionalType.Type + } + referenceType, ok := valueType.(*ReferenceType) if !ok { if !valueType.IsInvalidType() { @@ -97,7 +104,13 @@ func (checker *Checker) VisitUnaryExpression(expression *ast.UnaryExpression) Ty ) } - return innerType + if isOptional { + return &OptionalType{ + Type: innerType, + } + } else { + return innerType + } case ast.OperationMove: if !valueType.IsInvalidType() && diff --git a/runtime/tests/checker/reference_test.go b/runtime/tests/checker/reference_test.go index 2e3def0025..4801a7da6b 100644 --- a/runtime/tests/checker/reference_test.go +++ b/runtime/tests/checker/reference_test.go @@ -3467,4 +3467,34 @@ func TestCheckDereference(t *testing.T) { `, ) }) + + t.Run("Optional", func(t *testing.T) { + t.Parallel() + + runValidTestCase( + t, + "valid", + ` + let ref: &Int? = &1 as &Int + let y = *ref + `, + &sema.OptionalType{ + Type: sema.IntType, + }, + ) + + runInvalidTestCase( + t, + "invalid", + ` + struct S {} + + fun test() { + let s = S() + let ref: &S? = &s as &S + let deref = *ref + } + `, + ) + }) } diff --git a/runtime/tests/interpreter/reference_test.go b/runtime/tests/interpreter/reference_test.go index 9ba4efff38..c8e8fbb8ff 100644 --- a/runtime/tests/interpreter/reference_test.go +++ b/runtime/tests/interpreter/reference_test.go @@ -2931,4 +2931,35 @@ func TestInterpretDereference(t *testing.T) { interpreter.NewUnmeteredPathValue(common.PathDomainPublic, "temp"), ) }) + + t.Run("Optional", func(t *testing.T) { + t.Parallel() + + runTestCase( + t, + "nil", + ` + fun main(): Int? { + let ref: &Int? = nil + return *ref + } + `, + interpreter.Nil, + ) + + runTestCase( + t, + "some", + ` + fun main(): Int? { + let ref: &Int? = &42 as &Int + return *ref + } + `, + interpreter.NewUnmeteredSomeValueNonCopying( + interpreter.NewIntValueFromInt64(nil, 42), + ), + ) + }) + } From 0a9c47c20101ae4af2e8acb365f78f8f01919ba6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Wed, 10 Jan 2024 13:22:30 -0800 Subject: [PATCH 2/3] constructor does not need interpreter, memory gauge is sufficient --- runtime/interpreter/value.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index 1bc0d681ef..d812cf08cd 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -19371,8 +19371,8 @@ type SomeValue struct { isDestroyed bool } -func NewSomeValueNonCopying(interpreter *Interpreter, value Value) *SomeValue { - common.UseMemory(interpreter, common.OptionalValueMemoryUsage) +func NewSomeValueNonCopying(memoryGauge common.MemoryGauge, value Value) *SomeValue { + common.UseMemory(memoryGauge, common.OptionalValueMemoryUsage) return NewUnmeteredSomeValueNonCopying(value) } From 673e9e389ebebdccc057c1c7ec385b1d83bc1c1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Wed, 10 Jan 2024 13:22:41 -0800 Subject: [PATCH 3/3] clean up --- runtime/tests/checker/reference_test.go | 186 ++++++++++---------- runtime/tests/interpreter/reference_test.go | 158 ++++++++--------- 2 files changed, 169 insertions(+), 175 deletions(-) diff --git a/runtime/tests/checker/reference_test.go b/runtime/tests/checker/reference_test.go index 4801a7da6b..e89aab3513 100644 --- a/runtime/tests/checker/reference_test.go +++ b/runtime/tests/checker/reference_test.go @@ -1313,15 +1313,15 @@ func TestCheckInvalidDictionaryAccessOptionalReference(t *testing.T) { t.Parallel() _, err := ParseAndCheck(t, ` - access(all) struct S { - access(all) let foo: Number - init() { - self.foo = 0 - } - } - let dict: {String: S} = {} - let s = &dict[""] as &S? - let n = s.foo + access(all) struct S { + access(all) let foo: Number + init() { + self.foo = 0 + } + } + let dict: {String: S} = {} + let s = &dict[""] as &S? + let n = s.foo `) errs := RequireCheckerErrors(t, err, 1) @@ -1334,14 +1334,14 @@ func TestCheckInvalidDictionaryAccessNonOptionalReference(t *testing.T) { t.Parallel() _, err := ParseAndCheck(t, ` - access(all) struct S { - access(all) let foo: Number - init() { - self.foo = 0 - } - } - let dict: {String: S} = {} - let s = &dict[""] as &S + access(all) struct S { + access(all) let foo: Number + init() { + self.foo = 0 + } + } + let dict: {String: S} = {} + let s = &dict[""] as &S `) errs := RequireCheckerErrors(t, err, 1) @@ -1354,15 +1354,15 @@ func TestCheckArrayAccessReference(t *testing.T) { t.Parallel() _, err := ParseAndCheck(t, ` - access(all) struct S { - access(all) let foo: Number - init() { - self.foo = 0 - } - } - let dict: [S] = [] - let s = &dict[0] as &S - let n = s.foo + access(all) struct S { + access(all) let foo: Number + init() { + self.foo = 0 + } + } + let dict: [S] = [] + let s = &dict[0] as &S + let n = s.foo `) require.NoError(t, err) @@ -2939,9 +2939,9 @@ func TestCheckResourceReferenceMethodInvocationAfterMove(t *testing.T) { // Moving the resource should update the tracking var newFoo <- foo - fooRef.id + fooRef.id - destroy newFoo + destroy newFoo } `) @@ -3171,9 +3171,9 @@ func TestCheckDereference(t *testing.T) { typString, fmt.Sprintf( ` - let x: &%[1]s = &1 - let y: %[1]s = *x - `, + let x: &%[1]s = &1 + let y: %[1]s = *x + `, integerType, ), integerType, @@ -3189,9 +3189,9 @@ func TestCheckDereference(t *testing.T) { typString, fmt.Sprintf( ` - let x: &%[1]s = &1.0 - let y: %[1]s = *x - `, + let x: &%[1]s = &1.0 + let y: %[1]s = *x + `, fixedPointType, ), fixedPointType, @@ -3233,10 +3233,10 @@ func TestCheckDereference(t *testing.T) { testCase.ty.QualifiedString(), fmt.Sprintf( ` - let value: %[1]s = %[2]s - let x: &%[1]s = &value - let y: %[1]s = *x - `, + let value: %[1]s = %[2]s + let x: &%[1]s = &value + let y: %[1]s = *x + `, testCase.ty, testCase.initializer, ), @@ -3314,10 +3314,10 @@ func TestCheckDereference(t *testing.T) { testCase.ty.QualifiedString(), fmt.Sprintf( ` - let value: %[1]s = %[2]s - let x: &%[1]s = &value - let y: %[1]s = *x - `, + let value: %[1]s = %[2]s + let x: &%[1]s = &value + let y: %[1]s = *x + `, testCase.ty, testCase.initializer, ), @@ -3330,28 +3330,28 @@ func TestCheckDereference(t *testing.T) { t, "[Struct]", ` - struct S{} - - fun test() { - let value: [S] = [S(), S()] - let x: &[S] = &value - let y: [S] = *x - } - `, + struct S{} + + fun test() { + let value: [S] = [S(), S()] + let x: &[S] = &value + let y: [S] = *x + } + `, ) runInvalidTestCase( t, "[Struct; 3]", ` - struct S{} - - fun test() { - let value: [S; 3] = [S(),S(),S()] - let x: &[S; 3] = &value - let y: [S; 3] = *x - } - `, + struct S{} + + fun test() { + let value: [S; 3] = [S(),S(),S()] + let x: &[S; 3] = &value + let y: [S; 3] = *x + } + `, ) }) @@ -3396,10 +3396,10 @@ func TestCheckDereference(t *testing.T) { testCase.ty.QualifiedString(), fmt.Sprintf( ` - let value: %[1]s = %[2]s - let x: &%[1]s = &value - let y: %[1]s = *x - `, + let value: %[1]s = %[2]s + let x: &%[1]s = &value + let y: %[1]s = *x + `, testCase.ty, testCase.initializer, ), @@ -3412,14 +3412,14 @@ func TestCheckDereference(t *testing.T) { t, "{Int: Struct}", ` - struct S{} - - fun test() { - let value: {Int: S} = { 1: S(), 2: S() } - let x: &{Int: S} = &value - let y: {Int: S} = *x - } - `, + struct S{} + + fun test() { + let value: {Int: S} = { 1: S(), 2: S() } + let x: &{Int: S} = &value + let y: {Int: S} = *x + } + `, ) }) @@ -3430,22 +3430,22 @@ func TestCheckDereference(t *testing.T) { t, "Resource", ` - resource interface I { - fun foo() - } - - resource R: I { - fun foo() {} - } - - fun test() { - let r <- create R() - let ref = &r as &{I} - let deref <- *ref - destroy r + resource interface I { + fun foo() + } + + resource R: I { + fun foo() {} + } + + fun test() { + let r <- create R() + let ref = &r as &{I} + let deref <- *ref + destroy r destroy deref - } - `, + } + `, ) }) @@ -3457,14 +3457,14 @@ func TestCheckDereference(t *testing.T) { t, "Struct", ` - struct S{} - - fun test() { - let s = S() - let ref = &s as &S - let deref = *ref - } - `, + struct S {} + + fun test() { + let s = S() + let ref = &s as &S + let deref = *ref + } + `, ) }) diff --git a/runtime/tests/interpreter/reference_test.go b/runtime/tests/interpreter/reference_test.go index c8e8fbb8ff..3533c50cb0 100644 --- a/runtime/tests/interpreter/reference_test.go +++ b/runtime/tests/interpreter/reference_test.go @@ -1636,14 +1636,14 @@ func TestInterpretReferenceTrackingOnInvocation(t *testing.T) { fooRef.something() // just to trick the checker - fooRef = returnSameRef(fooRef) + fooRef = returnSameRef(fooRef) // Moving the resource should update the tracking var newFoo <- foo - fooRef.id + fooRef.id - destroy newFoo + destroy newFoo } `) @@ -1820,12 +1820,14 @@ func TestInterpretReferenceToReference(t *testing.T) { func TestInterpretDereference(t *testing.T) { t.Parallel() - runValidTestCase := func( + runTestCase := func( t *testing.T, name, code string, expectedValue interpreter.Value, ) { t.Run(name, func(t *testing.T) { + t.Parallel() + inter := parseCheckAndInterpret(t, code) value, err := inter.Invoke("main") @@ -1876,7 +1878,7 @@ func TestInterpretDereference(t *testing.T) { integerType := typ typString := typ.QualifiedString() - runValidTestCase( + runTestCase( t, typString, fmt.Sprintf( @@ -1885,7 +1887,7 @@ func TestInterpretDereference(t *testing.T) { let x: &%[1]s = &42 return *x } - `, + `, integerType, ), expectedValues[integerType], @@ -1911,7 +1913,7 @@ func TestInterpretDereference(t *testing.T) { fixedPointType := typ typString := typ.QualifiedString() - runValidTestCase( + runTestCase( t, typString, fmt.Sprintf( @@ -1920,7 +1922,7 @@ func TestInterpretDereference(t *testing.T) { let x: &%[1]s = &42.24 return *x } - `, + `, fixedPointType, ), expectedValues[fixedPointType], @@ -2747,12 +2749,12 @@ func TestInterpretDereference(t *testing.T) { inter := parseCheckAndInterpret( t, ` - fun main(): {Int: String} { - let original = {1: "ABC", 2: "DEF"} - let x: &{Int : String} = &original - return *x - } - `, + fun main(): {Int: String} { + let original = {1: "ABC", 2: "DEF"} + let x: &{Int : String} = &original + return *x + } + `, ) value, err := inter.Invoke("main") @@ -2781,12 +2783,12 @@ func TestInterpretDereference(t *testing.T) { inter := parseCheckAndInterpret( t, ` - fun main(): {Int: [String]} { - let original = {1: ["ABC", "XYZ"], 2: ["DEF"]} - let x: &{Int: [String]} = &original - return *x - } - `, + fun main(): {Int: [String]} { + let original = {1: ["ABC", "XYZ"], 2: ["DEF"]} + let x: &{Int: [String]} = &original + return *x + } + `, ) value, err := inter.Invoke("main") @@ -2834,16 +2836,16 @@ func TestInterpretDereference(t *testing.T) { t.Run("Character", func(t *testing.T) { t.Parallel() - runValidTestCase( + runTestCase( t, "Character", ` - fun main(): Character { - let original: Character = "S" - let x: &Character = &original - return *x - } - `, + fun main(): Character { + let original: Character = "S" + let x: &Character = &original + return *x + } + `, interpreter.NewUnmeteredCharacterValue("S"), ) }) @@ -2851,83 +2853,75 @@ func TestInterpretDereference(t *testing.T) { t.Run("String", func(t *testing.T) { t.Parallel() - runValidTestCase( + runTestCase( t, "String", ` - fun main(): String { - let original: String = "STxy" - let x: &String = &original - return *x - } - `, + fun main(): String { + let original: String = "STxy" + let x: &String = &original + return *x + } + `, interpreter.NewUnmeteredStringValue("STxy"), ) }) - t.Run("Bool", func(t *testing.T) { - t.Parallel() - - runValidTestCase( - t, - "Bool", - ` - fun main(): Bool { - let original: Bool = true - let x: &Bool = &original - return *x - } - `, - interpreter.BoolValue(true), - ) - }) - - t.Run("Address", func(t *testing.T) { - t.Parallel() + runTestCase( + t, + "Bool", + ` + fun main(): Bool { + let original: Bool = true + let x: &Bool = &original + return *x + } + `, + interpreter.BoolValue(true), + ) - address, err := common.HexToAddress("0x0000000000000231") - assert.NoError(t, err) + address, err := common.HexToAddress("0x0000000000000231") + assert.NoError(t, err) - runValidTestCase( - t, - "Address", - ` - fun main(): Address { - let original: Address = 0x0000000000000231 - let x: &Address = &original - return *x - } - `, - interpreter.NewAddressValue(nil, address), - ) - }) + runTestCase( + t, + "Address", + ` + fun main(): Address { + let original: Address = 0x0000000000000231 + let x: &Address = &original + return *x + } + `, + interpreter.NewAddressValue(nil, address), + ) t.Run("Path", func(t *testing.T) { t.Parallel() - runValidTestCase( + runTestCase( t, "PrivatePath", ` - fun main(): Path { - let original: Path = /private/temp - let x: &Path = &original - return *x - } - `, + fun main(): Path { + let original: Path = /private/temp + let x: &Path = &original + return *x + } + `, interpreter.NewUnmeteredPathValue(common.PathDomainPrivate, "temp"), ) - runValidTestCase( + runTestCase( t, "PublicPath", ` - fun main(): Path { - let original: Path = /public/temp - let x: &Path = &original - return *x - } - `, + fun main(): Path { + let original: Path = /public/temp + let x: &Path = &original + return *x + } + `, interpreter.NewUnmeteredPathValue(common.PathDomainPublic, "temp"), ) })