diff --git a/bbq/compiler/desugar.go b/bbq/compiler/desugar.go index df30a35dc..d37ff49a8 100644 --- a/bbq/compiler/desugar.go +++ b/bbq/compiler/desugar.go @@ -130,6 +130,11 @@ func (d *Desugar) VisitFunctionDeclaration(declaration *ast.FunctionDeclaration) // Add the remaining statements that are defined in this function. statements := funcBlock.Block.Statements modifiedStatements = append(modifiedStatements, statements...) + + } else if d.enclosingInterfaceType != nil { + // If this is an interface-method without a body, + // then do not generate a function for it. + return nil } // Before the post conditions are appended, we need to move the @@ -186,6 +191,7 @@ func (d *Desugar) VisitFunctionDeclaration(declaration *ast.FunctionDeclaration) nil, ) + // TODO: Is the generated function needed to be desugared again? return ast.NewFunctionDeclaration( d.memoryGauge, declaration.Access, @@ -656,6 +662,8 @@ func (d *Desugar) generateConditionsFunction( "", ) + // TODO: Is the generated function needed to be desugared? + d.modifiedDeclarations = append(d.modifiedDeclarations, conditionFunc) } @@ -836,8 +844,6 @@ func (d *Desugar) VisitAttachmentDeclaration(declaration *ast.AttachmentDeclarat } func (d *Desugar) VisitCompositeDeclaration(declaration *ast.CompositeDeclaration) ast.Declaration { - existingMembers := declaration.Members.Declarations() - compositeType := d.elaboration.CompositeDeclarationType(declaration) // Recursively de-sugar nested declarations (functions, types, etc.) @@ -850,17 +856,26 @@ func (d *Desugar) VisitCompositeDeclaration(declaration *ast.CompositeDeclaratio var desugaredMembers []ast.Declaration membersDesugared := false + existingMembers := declaration.Members.Declarations() for _, member := range existingMembers { desugaredMember := d.desugarDeclaration(member) + if desugaredMember == nil { + continue + } membersDesugared = membersDesugared || (desugaredMember != member) desugaredMembers = append(desugaredMembers, desugaredMember) } - // Copy over inherited default functions. - - inheritedDefaultFuncs := d.inheritedDefaultFunctions(compositeType, declaration) + // Add inherited default functions. + existingFunctions := declaration.Members.FunctionsByIdentifier() + inheritedDefaultFuncs := d.inheritedDefaultFunctions( + compositeType, + existingFunctions, + declaration.StartPos, + declaration.Range, + ) // Optimization: If none of the existing members got updated or, // if there are no inherited members, then return the same declaration as-is. @@ -868,10 +883,7 @@ func (d *Desugar) VisitCompositeDeclaration(declaration *ast.CompositeDeclaratio return declaration } - modifiedMembers := make([]ast.Declaration, len(desugaredMembers)) - copy(modifiedMembers, desugaredMembers) - - modifiedMembers = append(modifiedMembers, inheritedDefaultFuncs...) + desugaredMembers = append(desugaredMembers, inheritedDefaultFuncs...) modifiedDecl := ast.NewCompositeDeclaration( d.memoryGauge, @@ -879,7 +891,7 @@ func (d *Desugar) VisitCompositeDeclaration(declaration *ast.CompositeDeclaratio declaration.CompositeKind, declaration.Identifier, declaration.Conformances, - ast.NewMembers(d.memoryGauge, modifiedMembers), + ast.NewMembers(d.memoryGauge, desugaredMembers), declaration.DocString, declaration.Range, ) @@ -890,11 +902,10 @@ func (d *Desugar) VisitCompositeDeclaration(declaration *ast.CompositeDeclaratio return modifiedDecl } -func (d *Desugar) inheritedFunctionsWithConditions(compositeType *sema.CompositeType) map[string][]*inheritedFunction { +func (d *Desugar) inheritedFunctionsWithConditions(compositeType sema.ConformingType) map[string][]*inheritedFunction { inheritedFunctions := make(map[string][]*inheritedFunction) - for _, conformance := range compositeType.EffectiveInterfaceConformances() { - interfaceType := conformance.InterfaceType + compositeType.EffectiveInterfaceConformanceSet().ForEach(func(interfaceType *sema.InterfaceType) { elaboration, err := d.config.ElaborationResolver(interfaceType.Location) if err != nil { @@ -915,44 +926,24 @@ func (d *Desugar) inheritedFunctionsWithConditions(compositeType *sema.Composite }) inheritedFunctions[name] = funcs } - } + }) return inheritedFunctions } -func (d *Desugar) inheritedDefaultFunctions(compositeType *sema.CompositeType, decl *ast.CompositeDeclaration) []ast.Declaration { - directMembers := compositeType.Members - allMembers := compositeType.GetMembers() +func (d *Desugar) inheritedDefaultFunctions( + compositeType sema.ConformingType, + existingFunctions map[string]*ast.FunctionDeclaration, + pos ast.Position, + declRange ast.Range, +) []ast.Declaration { - pos := decl.StartPos + inheritedDefaultFunctions := make(map[string]struct{}) inheritedMembers := make([]ast.Declaration, 0) - for memberName, resolver := range allMembers { // nolint:maprange - if directMembers.Contains(memberName) { - continue - } - - member := resolver.Resolve( - d.memoryGauge, - memberName, - ast.EmptyRange, - func(err error) { - if err != nil { - panic(err) - } - }, - ) - - // Only interested in functions. - // Also filter out built-in functions. - if member.DeclarationKind != common.DeclarationKindFunction || - member.Predeclared { - continue - } - - // Inherited functions are always from interfaces - interfaceType := member.ContainerType.(*sema.InterfaceType) + for _, conformance := range compositeType.EffectiveInterfaceConformances() { + interfaceType := conformance.InterfaceType elaboration, err := d.config.ElaborationResolver(interfaceType.Location) if err != nil { @@ -960,71 +951,129 @@ func (d *Desugar) inheritedDefaultFunctions(compositeType *sema.CompositeType, d } interfaceDecl := elaboration.InterfaceTypeDeclaration(interfaceType) - functions := interfaceDecl.Members.FunctionsByIdentifier() - inheritedFunc, ok := functions[memberName] - if !ok { - panic(errors.NewUnreachableError()) - } - // for each inherited function, generate a delegator function, - // which calls the actual default implementation at the interface. - // i.e: - // FooImpl { - // fun defaultFunc(a1: T1, a2: T2): R { - // return FooInterface.defaultFunc(a1, a2) - // } - // } + for funcName, inheritedFunc := range functions { // nolint:maprange + if !inheritedFunc.FunctionBlock.HasStatements() { + continue + } + + // Pick the 'closest' default function. + // This is the same way how it is implemented in the interpreter. + _, ok := inheritedDefaultFunctions[funcName] + if ok { + continue + } + inheritedDefaultFunctions[funcName] = struct{}{} + + // If the inherited function is overridden by the current type, then skip. + if d.isFunctionOverridden(compositeType, funcName, existingFunctions) { + continue + } - // Generate: `FooInterface.defaultFunc(a1, a2)` + // For each inherited function, generate a delegator function, + // which calls the actual default implementation at the interface. + // i.e: + // FooImpl { + // fun defaultFunc(a1: T1, a2: T2): R { + // return FooInterface.defaultFunc(a1, a2) + // } + // } - inheritedFuncType := elaboration.FunctionDeclarationFunctionType(inheritedFunc) + // Generate: `FooInterface.defaultFunc(a1, a2)` - invocation := d.interfaceDelegationMethodCall( - interfaceType, - inheritedFuncType, - pos, - memberName, - member, - ) + inheritedFuncType := elaboration.FunctionDeclarationFunctionType(inheritedFunc) - // Generate: `fun defaultFunc(a1: T1, a2: T2) { ... }` - defaultFuncDelegator := ast.NewFunctionDeclaration( - d.memoryGauge, - inheritedFunc.Access, - inheritedFunc.Purity, - inheritedFunc.IsStatic(), - inheritedFunc.IsNative(), - ast.NewIdentifier( - d.memoryGauge, - memberName, + member, ok := interfaceType.MemberMap().Get(funcName) + if !ok { + panic(errors.NewUnreachableError()) + } + + invocation := d.interfaceDelegationMethodCall( + interfaceType, + inheritedFuncType, pos, - ), - inheritedFunc.TypeParameterList, - inheritedFunc.ParameterList, - inheritedFunc.ReturnTypeAnnotation, - ast.NewFunctionBlock( + funcName, + member, + ) + + funcReturnType := inheritedFuncType.ReturnTypeAnnotation.Type + returnStmt := ast.NewReturnStatement(d.memoryGauge, invocation, declRange) + d.elaboration.SetReturnStatementTypes( + returnStmt, + sema.ReturnStatementTypes{ + ValueType: funcReturnType, + ReturnType: funcReturnType, + }, + ) + + // Generate: `fun defaultFunc(a1: T1, a2: T2) { ... }` + defaultFuncDelegator := ast.NewFunctionDeclaration( d.memoryGauge, - ast.NewBlock( + inheritedFunc.Access, + inheritedFunc.Purity, + inheritedFunc.IsStatic(), + inheritedFunc.IsNative(), + ast.NewIdentifier( d.memoryGauge, - []ast.Statement{ - ast.NewReturnStatement(d.memoryGauge, invocation, decl.Range), - }, - decl.Range, + funcName, + pos, ), - nil, - nil, - ), - inheritedFunc.StartPos, - inheritedFunc.DocString, - ) + inheritedFunc.TypeParameterList, + inheritedFunc.ParameterList, + inheritedFunc.ReturnTypeAnnotation, + ast.NewFunctionBlock( + d.memoryGauge, + ast.NewBlock( + d.memoryGauge, + []ast.Statement{ + returnStmt, + }, + declRange, + ), + nil, + nil, + ), + inheritedFunc.StartPos, + inheritedFunc.DocString, + ) + + d.elaboration.SetFunctionDeclarationFunctionType(defaultFuncDelegator, inheritedFuncType) - inheritedMembers = append(inheritedMembers, defaultFuncDelegator) + // Pass the generated default function again through the desugar phase, + // so that it will properly link/chain the function conditions + // that are inherited/available for this default function. + desugaredDelegator := d.desugarDeclaration(defaultFuncDelegator) + + inheritedMembers = append(inheritedMembers, desugaredDelegator) + + } } return inheritedMembers } +func (d *Desugar) isFunctionOverridden( + enclosingType sema.ConformingType, + funcName string, + existingFunctions map[string]*ast.FunctionDeclaration, +) bool { + implementedFunc, isImplemented := existingFunctions[funcName] + if !isImplemented { + return false + } + + _, isInterface := enclosingType.(*sema.InterfaceType) + if isInterface { + // If the currently visiting declaration is an interface type (i.e: This function is an interface method) + // then it is considered as a default implementation only if there are statements. + // This is because interface methods can define conditions, without overriding the function. + return implementedFunc.FunctionBlock.HasStatements() + } + + return true +} + func (d *Desugar) interfaceDelegationMethodCall( interfaceType *sema.InterfaceType, inheritedFuncType *sema.FunctionType, @@ -1162,6 +1211,7 @@ func (d *Desugar) VisitInterfaceDeclaration(declaration *ast.InterfaceDeclaratio prevModifiedDecls := d.modifiedDeclarations prevEnclosingInterfaceType := d.enclosingInterfaceType + d.modifiedDeclarations = nil d.enclosingInterfaceType = interfaceType @@ -1170,26 +1220,30 @@ func (d *Desugar) VisitInterfaceDeclaration(declaration *ast.InterfaceDeclaratio d.enclosingInterfaceType = prevEnclosingInterfaceType }() - existingMembers := declaration.Members.Declarations() - // Recursively de-sugar nested declarations (functions, types, etc.) - membersDesugared := false - + existingMembers := declaration.Members.Declarations() for _, member := range existingMembers { desugaredMember := d.desugarDeclaration(member) - membersDesugared = membersDesugared || (desugaredMember != member) + if desugaredMember == nil { + continue + } d.modifiedDeclarations = append(d.modifiedDeclarations, desugaredMember) } - // Optimization: If none of the existing members got updated or, + // Add inherited default functions. + existingFunctions := declaration.Members.FunctionsByIdentifier() + inheritedDefaultFuncs := d.inheritedDefaultFunctions( + interfaceType, + existingFunctions, + declaration.StartPos, + declaration.Range, + ) + + d.modifiedDeclarations = append(d.modifiedDeclarations, inheritedDefaultFuncs...) + + // TODO: Optimize: If none of the existing members got updated or, // if there are no inherited members, then return the same declaration as-is. - //if !membersDesugared && len(inheritedDefaultFuncs) == 0 { - // return declaration - //} - if !membersDesugared { - return declaration - } modifiedDecl := ast.NewInterfaceDeclaration( d.memoryGauge, diff --git a/bbq/compiler/extended_elaboration.go b/bbq/compiler/extended_elaboration.go index ff9e276dd..3e6949a2a 100644 --- a/bbq/compiler/extended_elaboration.go +++ b/bbq/compiler/extended_elaboration.go @@ -37,6 +37,8 @@ type ExtendedElaboration struct { assignmentStatementTypes map[*ast.AssignmentStatement]sema.AssignmentStatementTypes resultVariableTypes map[ast.Element]sema.Type referenceExpressionBorrowTypes map[*ast.ReferenceExpression]sema.Type + functionDeclarationFunctionTypes map[*ast.FunctionDeclaration]*sema.FunctionType + returnStatementTypes map[*ast.ReturnStatement]sema.ReturnStatementTypes } func NewExtendedElaboration(elaboration *sema.Elaboration) *ExtendedElaboration { @@ -83,9 +85,23 @@ func (e *ExtendedElaboration) InterfaceDeclarationType(decl *ast.InterfaceDeclar } func (e *ExtendedElaboration) ReturnStatementTypes(statement *ast.ReturnStatement) sema.ReturnStatementTypes { + if e.returnStatementTypes != nil { + typ, ok := e.returnStatementTypes[statement] + if ok { + return typ + } + } + return e.elaboration.ReturnStatementTypes(statement) } +func (e *ExtendedElaboration) SetReturnStatementTypes(statement *ast.ReturnStatement, types sema.ReturnStatementTypes) { + if e.returnStatementTypes == nil { + e.returnStatementTypes = map[*ast.ReturnStatement]sema.ReturnStatementTypes{} + } + e.returnStatementTypes[statement] = types +} + func (e *ExtendedElaboration) VariableDeclarationTypes(declaration *ast.VariableDeclaration) sema.VariableDeclarationTypes { if e.variableDeclarationTypes != nil { typ, ok := e.variableDeclarationTypes[declaration] @@ -210,10 +226,6 @@ func (e *ExtendedElaboration) CastingExpressionTypes(expression *ast.CastingExpr return e.elaboration.CastingExpressionTypes(expression) } -func (e *ExtendedElaboration) FunctionDeclarationFunctionType(declaration *ast.FunctionDeclaration) *sema.FunctionType { - return e.elaboration.FunctionDeclarationFunctionType(declaration) -} - func (e *ExtendedElaboration) ResultVariableType(enclosingBlock ast.Element) (typ sema.Type, exist bool) { if e.resultVariableTypes != nil { types, ok := e.resultVariableTypes[enclosingBlock] @@ -247,3 +259,23 @@ func (e *ExtendedElaboration) SetReferenceExpressionBorrowType(expression *ast.R } e.referenceExpressionBorrowTypes[expression] = ty } + +func (e *ExtendedElaboration) FunctionDeclarationFunctionType(declaration *ast.FunctionDeclaration) *sema.FunctionType { + if e.functionDeclarationFunctionTypes != nil { + typ, ok := e.functionDeclarationFunctionTypes[declaration] + if ok { + return typ + } + } + return e.elaboration.FunctionDeclarationFunctionType(declaration) +} + +func (e *ExtendedElaboration) SetFunctionDeclarationFunctionType( + declaration *ast.FunctionDeclaration, + functionType *sema.FunctionType, +) { + if e.functionDeclarationFunctionTypes == nil { + e.functionDeclarationFunctionTypes = map[*ast.FunctionDeclaration]*sema.FunctionType{} + } + e.functionDeclarationFunctionTypes[declaration] = functionType +} diff --git a/bbq/vm/test/utils.go b/bbq/vm/test/utils.go index 83d41e27a..9ed03fdc5 100644 --- a/bbq/vm/test/utils.go +++ b/bbq/vm/test/utils.go @@ -440,7 +440,7 @@ func parseCheckAndCompileCodeWithOptions( options.ParseAndCheckOptions, programs, ) - programs[location] = &compiledProgram{ + programs[checker.Location] = &compiledProgram{ Elaboration: checker.Elaboration, } @@ -450,7 +450,7 @@ func parseCheckAndCompileCodeWithOptions( checker, programs, ) - programs[location].Program = program + programs[checker.Location].Program = program return program } diff --git a/bbq/vm/test/vm_test.go b/bbq/vm/test/vm_test.go index cae6c7acf..e1ced33f1 100644 --- a/bbq/vm/test/vm_test.go +++ b/bbq/vm/test/vm_test.go @@ -3272,3 +3272,120 @@ func TestFunctionPostConditions(t *testing.T) { }) } + +func TestDefaultFunctionsWithConditions(t *testing.T) { + + t.Parallel() + + t.Run("default in parent, conditions in child", func(t *testing.T) { + t.Parallel() + + storage := interpreter.NewInMemoryStorage(nil) + + activation := sema.NewVariableActivation(sema.BaseValueActivation) + activation.DeclareValue(stdlib.PanicFunction) + activation.DeclareValue(stdlib.NewStandardLibraryStaticFunction( + "log", + sema.NewSimpleFunctionType( + sema.FunctionPurityView, + []sema.Parameter{ + { + Label: sema.ArgumentLabelNotRequired, + Identifier: "value", + TypeAnnotation: sema.AnyStructTypeAnnotation, + }, + }, + sema.VoidTypeAnnotation, + ), + "", + nil, + )) + + var logs []string + vmConfig := &vm.Config{ + Storage: storage, + AccountHandler: &testAccountHandler{}, + //ImportHandler: func(location common.Location) *bbq.Program[opcode.Instruction] { + // program, ok := programs[location] + // if !ok { + // assert.FailNow(t, "invalid location") + // } + // return program.Program + //}, + //ContractValueHandler: func(_ *vm.Config, location common.Location) *vm.CompositeValue { + // contractValue, ok := contractValues[location] + // if !ok { + // assert.FailNow(t, "invalid location") + // } + // return contractValue + //}, + + NativeFunctionsProvider: func() map[string]vm.Value { + funcs := vm.NativeFunctions() + funcs[commons.LogFunctionName] = vm.NativeFunctionValue{ + ParameterCount: len(stdlib.LogFunctionType.Parameters), + Function: func(config *vm.Config, typeArguments []interpreter.StaticType, arguments ...vm.Value) vm.Value { + logs = append(logs, arguments[0].String()) + return vm.VoidValue{} + }, + } + + return funcs + }, + } + + _, err := compileAndInvokeWithOptions(t, ` + struct interface Foo { + fun test(_ a: Int) { + printMessage("invoked Foo.test()") + } + } + + struct interface Bar: Foo { + fun test(_ a: Int) { + pre { + printMessage("invoked Bar.test() pre-condition") + } + + post { + printMessage("invoked Bar.test() post-condition") + } + } + } + + struct Test: Bar {} + + access(all) view fun printMessage(_ msg: String): Bool { + log(msg) + return true + } + + fun main() { + Test().test(5) + } + `, + "main", + CompilerAndVMOptions{ + VMConfig: vmConfig, + ParseAndCheckOptions: &ParseAndCheckOptions{ + Config: &sema.Config{ + LocationHandler: singleIdentifierLocationResolver(t), + BaseValueActivationHandler: func(location common.Location) *sema.VariableActivation { + return activation + }, + }, + }, + }, + ) + + require.NoError(t, err) + require.Equal( + t, + []string{ + "invoked Bar.test() pre-condition", + "invoked Foo.test()", + "invoked Bar.test() post-condition", + }, logs, + ) + }) +} diff --git a/interpreter/function_test.go b/interpreter/function_test.go index b16b77c64..3134a3445 100644 --- a/interpreter/function_test.go +++ b/interpreter/function_test.go @@ -323,3 +323,36 @@ func TestInterpretFunctionSubtyping(t *testing.T) { result, ) } + +func TestInterpretDefaultFunctionWithConditions(t *testing.T) { + + t.Parallel() + + inter, getLogs, err := parseCheckAndInterpretWithLogs(t, ` + struct interface Foo { + fun test(_ a: Int) { + log("Calling to Foo.test") + } + } + + struct interface Bar: Foo { + fun test(_ a: Int) { + pre { + a > 10: "a must be greater than 10" + } + } + } + + struct Test: Bar {} + + fun main() { + Test().test(12) + }`, + ) + + _, err = inter.Invoke("main") + require.NoError(t, err) + + logs := getLogs() + require.Empty(t, logs) +} diff --git a/sema/type.go b/sema/type.go index 3b930f8d2..cda7bed58 100644 --- a/sema/type.go +++ b/sema/type.go @@ -316,6 +316,7 @@ func TypeActivationNestedType(typeActivation *VariableActivation, qualifiedIdent type ConformingType interface { Type EffectiveInterfaceConformanceSet() *InterfaceSet + EffectiveInterfaceConformances() []Conformance } // CompositeKindedType is a type which has a composite kind