Skip to content

Commit

Permalink
Improve default functions co-existence with conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
SupunS committed Feb 3, 2025
1 parent 1885788 commit 0ee77c8
Show file tree
Hide file tree
Showing 6 changed files with 347 additions and 110 deletions.
262 changes: 158 additions & 104 deletions bbq/compiler/desugar.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ func (d *Desugar) VisitFunctionDeclaration(declaration *ast.FunctionDeclaration)
// Add the remaining statements that are defined in this function.
statements := funcBlock.Block.Statements
modifiedStatements = append(modifiedStatements, statements...)

} else if d.enclosingInterfaceType != nil {
// If this is an interface-method without a body,
// then do not generate a function for it.
return nil
}

// Before the post conditions are appended, we need to move the
Expand Down Expand Up @@ -186,6 +191,7 @@ func (d *Desugar) VisitFunctionDeclaration(declaration *ast.FunctionDeclaration)
nil,
)

// TODO: Is the generated function needed to be desugared again?
return ast.NewFunctionDeclaration(
d.memoryGauge,
declaration.Access,
Expand Down Expand Up @@ -656,6 +662,8 @@ func (d *Desugar) generateConditionsFunction(
"",
)

// TODO: Is the generated function needed to be desugared?

d.modifiedDeclarations = append(d.modifiedDeclarations, conditionFunc)
}

Expand Down Expand Up @@ -836,8 +844,6 @@ func (d *Desugar) VisitAttachmentDeclaration(declaration *ast.AttachmentDeclarat
}

func (d *Desugar) VisitCompositeDeclaration(declaration *ast.CompositeDeclaration) ast.Declaration {
existingMembers := declaration.Members.Declarations()

compositeType := d.elaboration.CompositeDeclarationType(declaration)

// Recursively de-sugar nested declarations (functions, types, etc.)
Expand All @@ -850,36 +856,42 @@ func (d *Desugar) VisitCompositeDeclaration(declaration *ast.CompositeDeclaratio

var desugaredMembers []ast.Declaration
membersDesugared := false
existingMembers := declaration.Members.Declarations()

for _, member := range existingMembers {
desugaredMember := d.desugarDeclaration(member)
if desugaredMember == nil {
continue
}

membersDesugared = membersDesugared || (desugaredMember != member)
desugaredMembers = append(desugaredMembers, desugaredMember)
}

// Copy over inherited default functions.

inheritedDefaultFuncs := d.inheritedDefaultFunctions(compositeType, declaration)
// Add inherited default functions.
existingFunctions := declaration.Members.FunctionsByIdentifier()
inheritedDefaultFuncs := d.inheritedDefaultFunctions(
compositeType,
existingFunctions,
declaration.StartPos,
declaration.Range,
)

// Optimization: If none of the existing members got updated or,
// if there are no inherited members, then return the same declaration as-is.
if !membersDesugared && len(inheritedDefaultFuncs) == 0 {
return declaration
}

modifiedMembers := make([]ast.Declaration, len(desugaredMembers))
copy(modifiedMembers, desugaredMembers)

modifiedMembers = append(modifiedMembers, inheritedDefaultFuncs...)
desugaredMembers = append(desugaredMembers, inheritedDefaultFuncs...)

modifiedDecl := ast.NewCompositeDeclaration(
d.memoryGauge,
declaration.Access,
declaration.CompositeKind,
declaration.Identifier,
declaration.Conformances,
ast.NewMembers(d.memoryGauge, modifiedMembers),
ast.NewMembers(d.memoryGauge, desugaredMembers),
declaration.DocString,
declaration.Range,
)
Expand All @@ -890,11 +902,10 @@ func (d *Desugar) VisitCompositeDeclaration(declaration *ast.CompositeDeclaratio
return modifiedDecl
}

func (d *Desugar) inheritedFunctionsWithConditions(compositeType *sema.CompositeType) map[string][]*inheritedFunction {
func (d *Desugar) inheritedFunctionsWithConditions(compositeType sema.ConformingType) map[string][]*inheritedFunction {
inheritedFunctions := make(map[string][]*inheritedFunction)

for _, conformance := range compositeType.EffectiveInterfaceConformances() {
interfaceType := conformance.InterfaceType
compositeType.EffectiveInterfaceConformanceSet().ForEach(func(interfaceType *sema.InterfaceType) {

elaboration, err := d.config.ElaborationResolver(interfaceType.Location)
if err != nil {
Expand All @@ -915,116 +926,154 @@ func (d *Desugar) inheritedFunctionsWithConditions(compositeType *sema.Composite
})
inheritedFunctions[name] = funcs
}
}
})

return inheritedFunctions
}

func (d *Desugar) inheritedDefaultFunctions(compositeType *sema.CompositeType, decl *ast.CompositeDeclaration) []ast.Declaration {
directMembers := compositeType.Members
allMembers := compositeType.GetMembers()
func (d *Desugar) inheritedDefaultFunctions(
compositeType sema.ConformingType,
existingFunctions map[string]*ast.FunctionDeclaration,
pos ast.Position,
declRange ast.Range,
) []ast.Declaration {

pos := decl.StartPos
inheritedDefaultFunctions := make(map[string]struct{})

inheritedMembers := make([]ast.Declaration, 0)

for memberName, resolver := range allMembers { // nolint:maprange
if directMembers.Contains(memberName) {
continue
}

member := resolver.Resolve(
d.memoryGauge,
memberName,
ast.EmptyRange,
func(err error) {
if err != nil {
panic(err)
}
},
)

// Only interested in functions.
// Also filter out built-in functions.
if member.DeclarationKind != common.DeclarationKindFunction ||
member.Predeclared {
continue
}

// Inherited functions are always from interfaces
interfaceType := member.ContainerType.(*sema.InterfaceType)
for _, conformance := range compositeType.EffectiveInterfaceConformances() {
interfaceType := conformance.InterfaceType

elaboration, err := d.config.ElaborationResolver(interfaceType.Location)
if err != nil {
panic(err)
}

interfaceDecl := elaboration.InterfaceTypeDeclaration(interfaceType)

functions := interfaceDecl.Members.FunctionsByIdentifier()
inheritedFunc, ok := functions[memberName]
if !ok {
panic(errors.NewUnreachableError())
}

// for each inherited function, generate a delegator function,
// which calls the actual default implementation at the interface.
// i.e:
// FooImpl {
// fun defaultFunc(a1: T1, a2: T2): R {
// return FooInterface.defaultFunc(a1, a2)
// }
// }
for funcName, inheritedFunc := range functions { // nolint:maprange
if !inheritedFunc.FunctionBlock.HasStatements() {
continue
}

// Pick the 'closest' default function.
// This is the same way how it is implemented in the interpreter.
_, ok := inheritedDefaultFunctions[funcName]
if ok {
continue
}
inheritedDefaultFunctions[funcName] = struct{}{}

// If the inherited function is overridden by the current type, then skip.
if d.isFunctionOverridden(compositeType, funcName, existingFunctions) {
continue
}

// Generate: `FooInterface.defaultFunc(a1, a2)`
// For each inherited function, generate a delegator function,
// which calls the actual default implementation at the interface.
// i.e:
// FooImpl {
// fun defaultFunc(a1: T1, a2: T2): R {
// return FooInterface.defaultFunc(a1, a2)
// }
// }

inheritedFuncType := elaboration.FunctionDeclarationFunctionType(inheritedFunc)
// Generate: `FooInterface.defaultFunc(a1, a2)`

invocation := d.interfaceDelegationMethodCall(
interfaceType,
inheritedFuncType,
pos,
memberName,
member,
)
inheritedFuncType := elaboration.FunctionDeclarationFunctionType(inheritedFunc)

// Generate: `fun defaultFunc(a1: T1, a2: T2) { ... }`
defaultFuncDelegator := ast.NewFunctionDeclaration(
d.memoryGauge,
inheritedFunc.Access,
inheritedFunc.Purity,
inheritedFunc.IsStatic(),
inheritedFunc.IsNative(),
ast.NewIdentifier(
d.memoryGauge,
memberName,
member, ok := interfaceType.MemberMap().Get(funcName)
if !ok {
panic(errors.NewUnreachableError())
}

invocation := d.interfaceDelegationMethodCall(
interfaceType,
inheritedFuncType,
pos,
),
inheritedFunc.TypeParameterList,
inheritedFunc.ParameterList,
inheritedFunc.ReturnTypeAnnotation,
ast.NewFunctionBlock(
funcName,
member,
)

funcReturnType := inheritedFuncType.ReturnTypeAnnotation.Type
returnStmt := ast.NewReturnStatement(d.memoryGauge, invocation, declRange)
d.elaboration.SetReturnStatementTypes(
returnStmt,
sema.ReturnStatementTypes{
ValueType: funcReturnType,
ReturnType: funcReturnType,
},
)

// Generate: `fun defaultFunc(a1: T1, a2: T2) { ... }`
defaultFuncDelegator := ast.NewFunctionDeclaration(
d.memoryGauge,
ast.NewBlock(
inheritedFunc.Access,
inheritedFunc.Purity,
inheritedFunc.IsStatic(),
inheritedFunc.IsNative(),
ast.NewIdentifier(
d.memoryGauge,
[]ast.Statement{
ast.NewReturnStatement(d.memoryGauge, invocation, decl.Range),
},
decl.Range,
funcName,
pos,
),
nil,
nil,
),
inheritedFunc.StartPos,
inheritedFunc.DocString,
)
inheritedFunc.TypeParameterList,
inheritedFunc.ParameterList,
inheritedFunc.ReturnTypeAnnotation,
ast.NewFunctionBlock(
d.memoryGauge,
ast.NewBlock(
d.memoryGauge,
[]ast.Statement{
returnStmt,
},
declRange,
),
nil,
nil,
),
inheritedFunc.StartPos,
inheritedFunc.DocString,
)

d.elaboration.SetFunctionDeclarationFunctionType(defaultFuncDelegator, inheritedFuncType)

inheritedMembers = append(inheritedMembers, defaultFuncDelegator)
// Pass the generated default function again through the desugar phase,
// so that it will properly link/chain the function conditions
// that are inherited/available for this default function.
desugaredDelegator := d.desugarDeclaration(defaultFuncDelegator)

inheritedMembers = append(inheritedMembers, desugaredDelegator)

}
}

return inheritedMembers
}

func (d *Desugar) isFunctionOverridden(
enclosingType sema.ConformingType,
funcName string,
existingFunctions map[string]*ast.FunctionDeclaration,
) bool {
implementedFunc, isImplemented := existingFunctions[funcName]
if !isImplemented {
return false
}

_, isInterface := enclosingType.(*sema.InterfaceType)
if isInterface {
// If the currently visiting declaration is an interface type (i.e: This function is an interface method)
// then it is considered as a default implementation only if there are statements.
// This is because interface methods can define conditions, without overriding the function.
return implementedFunc.FunctionBlock.HasStatements()
}

return true
}

func (d *Desugar) interfaceDelegationMethodCall(
interfaceType *sema.InterfaceType,
inheritedFuncType *sema.FunctionType,
Expand Down Expand Up @@ -1162,6 +1211,7 @@ func (d *Desugar) VisitInterfaceDeclaration(declaration *ast.InterfaceDeclaratio

prevModifiedDecls := d.modifiedDeclarations
prevEnclosingInterfaceType := d.enclosingInterfaceType

d.modifiedDeclarations = nil
d.enclosingInterfaceType = interfaceType

Expand All @@ -1170,26 +1220,30 @@ func (d *Desugar) VisitInterfaceDeclaration(declaration *ast.InterfaceDeclaratio
d.enclosingInterfaceType = prevEnclosingInterfaceType
}()

existingMembers := declaration.Members.Declarations()

// Recursively de-sugar nested declarations (functions, types, etc.)

membersDesugared := false

existingMembers := declaration.Members.Declarations()
for _, member := range existingMembers {
desugaredMember := d.desugarDeclaration(member)
membersDesugared = membersDesugared || (desugaredMember != member)
if desugaredMember == nil {
continue
}
d.modifiedDeclarations = append(d.modifiedDeclarations, desugaredMember)
}

// Optimization: If none of the existing members got updated or,
// Add inherited default functions.
existingFunctions := declaration.Members.FunctionsByIdentifier()
inheritedDefaultFuncs := d.inheritedDefaultFunctions(
interfaceType,
existingFunctions,
declaration.StartPos,
declaration.Range,
)

d.modifiedDeclarations = append(d.modifiedDeclarations, inheritedDefaultFuncs...)

// TODO: Optimize: If none of the existing members got updated or,
// if there are no inherited members, then return the same declaration as-is.
//if !membersDesugared && len(inheritedDefaultFuncs) == 0 {
// return declaration
//}
if !membersDesugared {
return declaration
}

modifiedDecl := ast.NewInterfaceDeclaration(
d.memoryGauge,
Expand Down
Loading

0 comments on commit 0ee77c8

Please sign in to comment.