diff --git a/runtime/interpreter/interpreter_expression.go b/runtime/interpreter/interpreter_expression.go index fc162a06c5..c8c931ec2b 100644 --- a/runtime/interpreter/interpreter_expression.go +++ b/runtime/interpreter/interpreter_expression.go @@ -697,16 +697,33 @@ func (interpreter *Interpreter) VisitUnaryExpression(expression *ast.UnaryExpres ) case ast.OperationMul: - referenceValue, ok := value.(ReferenceValue) - if !ok { - panic(errors.NewUnreachableError()) + + if _, ok := value.(NilValue); ok { + return Nil } + locationRange := LocationRange{ Location: interpreter.Location, HasPosition: expression, } + var isOptional bool + + if someValue, ok := value.(*SomeValue); ok { + isOptional = true + value = someValue.InnerValue(interpreter, locationRange) + } - return DereferenceValue(interpreter, locationRange, referenceValue) + referenceValue, ok := value.(ReferenceValue) + if !ok { + panic(errors.NewUnreachableError()) + } + + dereferencedValue := DereferenceValue(interpreter, locationRange, referenceValue) + if isOptional { + return NewSomeValueNonCopying(interpreter, dereferencedValue) + } else { + return dereferencedValue + } case ast.OperationMove: interpreter.invalidateResource(value) diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index acb95d0040..42a548c337 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -19432,8 +19432,8 @@ type SomeValue struct { isDestroyed bool } -func NewSomeValueNonCopying(interpreter *Interpreter, value Value) *SomeValue { - common.UseMemory(interpreter, common.OptionalValueMemoryUsage) +func NewSomeValueNonCopying(memoryGauge common.MemoryGauge, value Value) *SomeValue { + common.UseMemory(memoryGauge, common.OptionalValueMemoryUsage) return NewUnmeteredSomeValueNonCopying(value) } diff --git a/runtime/sema/check_unary_expression.go b/runtime/sema/check_unary_expression.go index 1cf061a46e..0503a1d8db 100644 --- a/runtime/sema/check_unary_expression.go +++ b/runtime/sema/check_unary_expression.go @@ -62,6 +62,13 @@ func (checker *Checker) VisitUnaryExpression(expression *ast.UnaryExpression) Ty return checkExpectedType(valueType, SignedNumberType) case ast.OperationMul: + + var isOptional bool + if optionalType, ok := valueType.(*OptionalType); ok { + isOptional = true + valueType = optionalType.Type + } + referenceType, ok := valueType.(*ReferenceType) if !ok { if !valueType.IsInvalidType() { @@ -97,7 +104,13 @@ func (checker *Checker) VisitUnaryExpression(expression *ast.UnaryExpression) Ty ) } - return innerType + if isOptional { + return &OptionalType{ + Type: innerType, + } + } else { + return innerType + } case ast.OperationMove: if !valueType.IsInvalidType() && diff --git a/runtime/tests/checker/reference_test.go b/runtime/tests/checker/reference_test.go index 5e7ec301a6..a3852e7266 100644 --- a/runtime/tests/checker/reference_test.go +++ b/runtime/tests/checker/reference_test.go @@ -1313,15 +1313,15 @@ func TestCheckInvalidDictionaryAccessOptionalReference(t *testing.T) { t.Parallel() _, err := ParseAndCheck(t, ` - access(all) struct S { - access(all) let foo: Number - init() { - self.foo = 0 - } - } - let dict: {String: S} = {} - let s = &dict[""] as &S? - let n = s.foo + access(all) struct S { + access(all) let foo: Number + init() { + self.foo = 0 + } + } + let dict: {String: S} = {} + let s = &dict[""] as &S? + let n = s.foo `) errs := RequireCheckerErrors(t, err, 1) @@ -1334,14 +1334,14 @@ func TestCheckInvalidDictionaryAccessNonOptionalReference(t *testing.T) { t.Parallel() _, err := ParseAndCheck(t, ` - access(all) struct S { - access(all) let foo: Number - init() { - self.foo = 0 - } - } - let dict: {String: S} = {} - let s = &dict[""] as &S + access(all) struct S { + access(all) let foo: Number + init() { + self.foo = 0 + } + } + let dict: {String: S} = {} + let s = &dict[""] as &S `) errs := RequireCheckerErrors(t, err, 1) @@ -1354,15 +1354,15 @@ func TestCheckArrayAccessReference(t *testing.T) { t.Parallel() _, err := ParseAndCheck(t, ` - access(all) struct S { - access(all) let foo: Number - init() { - self.foo = 0 - } - } - let dict: [S] = [] - let s = &dict[0] as &S - let n = s.foo + access(all) struct S { + access(all) let foo: Number + init() { + self.foo = 0 + } + } + let dict: [S] = [] + let s = &dict[0] as &S + let n = s.foo `) require.NoError(t, err) @@ -2939,9 +2939,9 @@ func TestCheckResourceReferenceMethodInvocationAfterMove(t *testing.T) { // Moving the resource should update the tracking var newFoo <- foo - fooRef.id + fooRef.id - destroy newFoo + destroy newFoo } `) @@ -3171,9 +3171,9 @@ func TestCheckDereference(t *testing.T) { typString, fmt.Sprintf( ` - let x: &%[1]s = &1 - let y: %[1]s = *x - `, + let x: &%[1]s = &1 + let y: %[1]s = *x + `, integerType, ), integerType, @@ -3189,9 +3189,9 @@ func TestCheckDereference(t *testing.T) { typString, fmt.Sprintf( ` - let x: &%[1]s = &1.0 - let y: %[1]s = *x - `, + let x: &%[1]s = &1.0 + let y: %[1]s = *x + `, fixedPointType, ), fixedPointType, @@ -3233,10 +3233,10 @@ func TestCheckDereference(t *testing.T) { testCase.ty.QualifiedString(), fmt.Sprintf( ` - let value: %[1]s = %[2]s - let x: &%[1]s = &value - let y: %[1]s = *x - `, + let value: %[1]s = %[2]s + let x: &%[1]s = &value + let y: %[1]s = *x + `, testCase.ty, testCase.initializer, ), @@ -3314,10 +3314,10 @@ func TestCheckDereference(t *testing.T) { testCase.ty.QualifiedString(), fmt.Sprintf( ` - let value: %[1]s = %[2]s - let x: &%[1]s = &value - let y: %[1]s = *x - `, + let value: %[1]s = %[2]s + let x: &%[1]s = &value + let y: %[1]s = *x + `, testCase.ty, testCase.initializer, ), @@ -3330,28 +3330,28 @@ func TestCheckDereference(t *testing.T) { t, "[Struct]", ` - struct S{} - - fun test() { - let value: [S] = [S(), S()] - let x: &[S] = &value - let y: [S] = *x - } - `, + struct S{} + + fun test() { + let value: [S] = [S(), S()] + let x: &[S] = &value + let y: [S] = *x + } + `, ) runInvalidTestCase( t, "[Struct; 3]", ` - struct S{} - - fun test() { - let value: [S; 3] = [S(),S(),S()] - let x: &[S; 3] = &value - let y: [S; 3] = *x - } - `, + struct S{} + + fun test() { + let value: [S; 3] = [S(),S(),S()] + let x: &[S; 3] = &value + let y: [S; 3] = *x + } + `, ) }) @@ -3396,10 +3396,10 @@ func TestCheckDereference(t *testing.T) { testCase.ty.QualifiedString(), fmt.Sprintf( ` - let value: %[1]s = %[2]s - let x: &%[1]s = &value - let y: %[1]s = *x - `, + let value: %[1]s = %[2]s + let x: &%[1]s = &value + let y: %[1]s = *x + `, testCase.ty, testCase.initializer, ), @@ -3412,14 +3412,14 @@ func TestCheckDereference(t *testing.T) { t, "{Int: Struct}", ` - struct S{} - - fun test() { - let value: {Int: S} = { 1: S(), 2: S() } - let x: &{Int: S} = &value - let y: {Int: S} = *x - } - `, + struct S{} + + fun test() { + let value: {Int: S} = { 1: S(), 2: S() } + let x: &{Int: S} = &value + let y: {Int: S} = *x + } + `, ) }) @@ -3427,36 +3427,36 @@ func TestCheckDereference(t *testing.T) { t, "Resource", ` - resource interface I { - fun foo() - } + resource interface I { + fun foo() + } - resource R: I { - fun foo() {} - } + resource R: I { + fun foo() {} + } - fun test() { - let r <- create R() - let ref = &r as &{I} - let deref <- *ref - destroy r - destroy deref - } - `, + fun test() { + let r <- create R() + let ref = &r as &{I} + let deref <- *ref + destroy r + destroy deref + } + `, ) runInvalidTestCase( t, "Struct", ` - struct S{} + struct S{} - fun test() { - let s = S() - let ref = &s as &S - let deref = *ref - } - `, + fun test() { + let s = S() + let ref = &s as &S + let deref = *ref + } + `, ) t.Run("built-in", func(t *testing.T) { @@ -3467,10 +3467,40 @@ func TestCheckDereference(t *testing.T) { t, "Account", ` - fun test(ref: &Account): Account { - return *ref - } - `, + fun test(ref: &Account): Account { + return *ref + } + `, + ) + }) + + t.Run("Optional", func(t *testing.T) { + t.Parallel() + + runValidTestCase( + t, + "valid", + ` + let ref: &Int? = &1 as &Int + let y = *ref + `, + &sema.OptionalType{ + Type: sema.IntType, + }, + ) + + runInvalidTestCase( + t, + "invalid", + ` + struct S {} + + fun test() { + let s = S() + let ref: &S? = &s as &S + let deref = *ref + } + `, ) }) } diff --git a/runtime/tests/interpreter/reference_test.go b/runtime/tests/interpreter/reference_test.go index 9ba4efff38..3533c50cb0 100644 --- a/runtime/tests/interpreter/reference_test.go +++ b/runtime/tests/interpreter/reference_test.go @@ -1636,14 +1636,14 @@ func TestInterpretReferenceTrackingOnInvocation(t *testing.T) { fooRef.something() // just to trick the checker - fooRef = returnSameRef(fooRef) + fooRef = returnSameRef(fooRef) // Moving the resource should update the tracking var newFoo <- foo - fooRef.id + fooRef.id - destroy newFoo + destroy newFoo } `) @@ -1820,12 +1820,14 @@ func TestInterpretReferenceToReference(t *testing.T) { func TestInterpretDereference(t *testing.T) { t.Parallel() - runValidTestCase := func( + runTestCase := func( t *testing.T, name, code string, expectedValue interpreter.Value, ) { t.Run(name, func(t *testing.T) { + t.Parallel() + inter := parseCheckAndInterpret(t, code) value, err := inter.Invoke("main") @@ -1876,7 +1878,7 @@ func TestInterpretDereference(t *testing.T) { integerType := typ typString := typ.QualifiedString() - runValidTestCase( + runTestCase( t, typString, fmt.Sprintf( @@ -1885,7 +1887,7 @@ func TestInterpretDereference(t *testing.T) { let x: &%[1]s = &42 return *x } - `, + `, integerType, ), expectedValues[integerType], @@ -1911,7 +1913,7 @@ func TestInterpretDereference(t *testing.T) { fixedPointType := typ typString := typ.QualifiedString() - runValidTestCase( + runTestCase( t, typString, fmt.Sprintf( @@ -1920,7 +1922,7 @@ func TestInterpretDereference(t *testing.T) { let x: &%[1]s = &42.24 return *x } - `, + `, fixedPointType, ), expectedValues[fixedPointType], @@ -2747,12 +2749,12 @@ func TestInterpretDereference(t *testing.T) { inter := parseCheckAndInterpret( t, ` - fun main(): {Int: String} { - let original = {1: "ABC", 2: "DEF"} - let x: &{Int : String} = &original - return *x - } - `, + fun main(): {Int: String} { + let original = {1: "ABC", 2: "DEF"} + let x: &{Int : String} = &original + return *x + } + `, ) value, err := inter.Invoke("main") @@ -2781,12 +2783,12 @@ func TestInterpretDereference(t *testing.T) { inter := parseCheckAndInterpret( t, ` - fun main(): {Int: [String]} { - let original = {1: ["ABC", "XYZ"], 2: ["DEF"]} - let x: &{Int: [String]} = &original - return *x - } - `, + fun main(): {Int: [String]} { + let original = {1: ["ABC", "XYZ"], 2: ["DEF"]} + let x: &{Int: [String]} = &original + return *x + } + `, ) value, err := inter.Invoke("main") @@ -2834,16 +2836,16 @@ func TestInterpretDereference(t *testing.T) { t.Run("Character", func(t *testing.T) { t.Parallel() - runValidTestCase( + runTestCase( t, "Character", ` - fun main(): Character { - let original: Character = "S" - let x: &Character = &original - return *x - } - `, + fun main(): Character { + let original: Character = "S" + let x: &Character = &original + return *x + } + `, interpreter.NewUnmeteredCharacterValue("S"), ) }) @@ -2851,84 +2853,107 @@ func TestInterpretDereference(t *testing.T) { t.Run("String", func(t *testing.T) { t.Parallel() - runValidTestCase( + runTestCase( t, "String", ` - fun main(): String { - let original: String = "STxy" - let x: &String = &original - return *x - } - `, + fun main(): String { + let original: String = "STxy" + let x: &String = &original + return *x + } + `, interpreter.NewUnmeteredStringValue("STxy"), ) }) - t.Run("Bool", func(t *testing.T) { + runTestCase( + t, + "Bool", + ` + fun main(): Bool { + let original: Bool = true + let x: &Bool = &original + return *x + } + `, + interpreter.BoolValue(true), + ) + + address, err := common.HexToAddress("0x0000000000000231") + assert.NoError(t, err) + + runTestCase( + t, + "Address", + ` + fun main(): Address { + let original: Address = 0x0000000000000231 + let x: &Address = &original + return *x + } + `, + interpreter.NewAddressValue(nil, address), + ) + + t.Run("Path", func(t *testing.T) { t.Parallel() - runValidTestCase( + runTestCase( t, - "Bool", + "PrivatePath", ` - fun main(): Bool { - let original: Bool = true - let x: &Bool = &original - return *x - } - `, - interpreter.BoolValue(true), + fun main(): Path { + let original: Path = /private/temp + let x: &Path = &original + return *x + } + `, + interpreter.NewUnmeteredPathValue(common.PathDomainPrivate, "temp"), ) - }) - - t.Run("Address", func(t *testing.T) { - t.Parallel() - - address, err := common.HexToAddress("0x0000000000000231") - assert.NoError(t, err) - runValidTestCase( + runTestCase( t, - "Address", + "PublicPath", ` - fun main(): Address { - let original: Address = 0x0000000000000231 - let x: &Address = &original - return *x - } - `, - interpreter.NewAddressValue(nil, address), + fun main(): Path { + let original: Path = /public/temp + let x: &Path = &original + return *x + } + `, + interpreter.NewUnmeteredPathValue(common.PathDomainPublic, "temp"), ) }) - t.Run("Path", func(t *testing.T) { + t.Run("Optional", func(t *testing.T) { t.Parallel() - runValidTestCase( + runTestCase( t, - "PrivatePath", + "nil", ` - fun main(): Path { - let original: Path = /private/temp - let x: &Path = &original - return *x - } - `, - interpreter.NewUnmeteredPathValue(common.PathDomainPrivate, "temp"), + fun main(): Int? { + let ref: &Int? = nil + return *ref + } + `, + interpreter.Nil, ) - runValidTestCase( + runTestCase( t, - "PublicPath", + "some", ` - fun main(): Path { - let original: Path = /public/temp - let x: &Path = &original - return *x - } - `, - interpreter.NewUnmeteredPathValue(common.PathDomainPublic, "temp"), + fun main(): Int? { + let ref: &Int? = &42 as &Int + return *ref + } + `, + interpreter.NewUnmeteredSomeValueNonCopying( + interpreter.NewIntValueFromInt64(nil, 42), + ), ) }) + }