diff --git a/src/lib/compiler/compiler.go b/src/lib/compiler/compiler.go index bcb1e1a..71d8204 100644 --- a/src/lib/compiler/compiler.go +++ b/src/lib/compiler/compiler.go @@ -201,10 +201,12 @@ func (c *Compiler) ImportAll(path string, ctx *Context) error { if s.Export.FunctionDefinition != nil { var params []*ir.Param for _, p := range s.Export.FunctionDefinition.Parameters { - params = append(params, ir.NewParam(p.Name, ctx.StringToType(p.Type))) + params = append(params, ir.NewParam(p.Name, ctx.CFTypeToLLType(p.Type))) + } + fn := c.Module.NewFunc(s.Export.FunctionDefinition.Name.Name, ctx.CFMultiTypeToLLType(s.Export.FunctionDefinition.ReturnType), params...) + if s.Export.FunctionDefinition.Variadic != "" { + fn.Sig.Variadic = true } - fn := c.Module.NewFunc(s.Export.FunctionDefinition.Name.Name, ctx.StringToType(s.Export.FunctionDefinition.ReturnType), params...) - fn.Sig.Variadic = s.Export.FunctionDefinition.Variadic ctx.SymbolTable[s.Export.FunctionDefinition.Name.Name] = fn } else if s.Export.ClassDefinition != nil { @@ -214,28 +216,29 @@ func (c *Compiler) ImportAll(path string, ctx *Context) error { ctx.structNames[cStruct] = s.Export.ClassDefinition.Name for _, st := range s.Export.ClassDefinition.Body { if st.FieldDefinition != nil { - cStruct.Fields = append(cStruct.Fields, ctx.StringToType(st.FieldDefinition.Type)) + cStruct.Fields = append(cStruct.Fields, ctx.CFTypeToLLType(st.FieldDefinition.Type)) ctx.Compiler.StructFields[s.Export.ClassDefinition.Name] = append(ctx.Compiler.StructFields[s.Export.ClassDefinition.Name], st.FieldDefinition) } else if st.FunctionDefinition != nil { f := st.FunctionDefinition var params []*ir.Param params = append(params, ir.NewParam("this", types.NewPointer(cStruct))) for _, arg := range f.Parameters { - params = append(params, ir.NewParam(arg.Name, ctx.StringToType(arg.Type))) + params = append(params, ir.NewParam(arg.Name, ctx.CFTypeToLLType(arg.Type))) } ms := "." + f.Name.Name if f.Name.Op { - ms = ".op." + strings.Trim(f.Name.String, "\"") + ms = ".op." + strings.Trim(f.Name.Name, "\"") } else if f.Name.Get { - ms = ".get." + strings.Trim(f.Name.String, "\"") + ms = ".get." + strings.Trim(f.Name.Name, "\"") } else if f.Name.Set { - ms = ".set." + strings.Trim(f.Name.String, "\"") + ms = ".set." + strings.Trim(f.Name.Name, "\"") } - fn := ctx.Module.NewFunc(s.Export.ClassDefinition.Name+ms, ctx.StringToType(f.ReturnType), params...) - fn.Sig.Variadic = st.FunctionDefinition.Variadic - fn.Sig.RetType = ctx.StringToType(f.ReturnType) + fn := ctx.Module.NewFunc(s.Export.ClassDefinition.Name+ms, ctx.CFMultiTypeToLLType(f.ReturnType), params...) + if st.FunctionDefinition.Variadic != "" { + fn.Sig.Variadic = true + } ctx.SymbolTable[s.Export.ClassDefinition.Name+ms] = fn } @@ -243,9 +246,9 @@ func (c *Compiler) ImportAll(path string, ctx *Context) error { } else if s.Export.External != nil { var params []*ir.Param for _, p := range s.Export.External.Parameters { - params = append(params, ir.NewParam(p.Name, ctx.StringToType(p.Type))) + params = append(params, ir.NewParam(p.Name, ctx.CFTypeToLLType(p.Type))) } - fn := c.Module.NewFunc(s.Export.External.Name, ctx.StringToType(s.Export.External.ReturnType), params...) + fn := c.Module.NewFunc(s.Export.External.Name, ctx.CFMultiTypeToLLType(s.Export.External.ReturnType), params...) fn.Sig.Variadic = s.Export.External.Variadic ctx.SymbolTable[s.Export.External.Name] = fn } else { @@ -283,9 +286,9 @@ func (c *Compiler) ImportAs(path string, symbols map[string]string, ctx *Context if newname, ok := symbols[s.Export.FunctionDefinition.Name.Name]; ok { var params []*ir.Param for _, p := range s.Export.FunctionDefinition.Parameters { - params = append(params, ir.NewParam(p.Name, ctx.StringToType(p.Type))) + params = append(params, ir.NewParam(p.Name, ctx.CFTypeToLLType(p.Type))) } - fn := c.Module.NewFunc(s.Export.FunctionDefinition.Name.Name, ctx.StringToType(s.Export.FunctionDefinition.ReturnType), params...) + fn := c.Module.NewFunc(s.Export.FunctionDefinition.Name.Name, ctx.CFMultiTypeToLLType(s.Export.FunctionDefinition.ReturnType), params...) if newname == "" { newname = s.Export.FunctionDefinition.Name.Name } @@ -299,26 +302,27 @@ func (c *Compiler) ImportAs(path string, symbols map[string]string, ctx *Context cStruct := types.NewStruct() for _, st := range s.Export.ClassDefinition.Body { if st.FieldDefinition != nil { - cStruct.Fields = append(cStruct.Fields, ctx.StringToType(st.FieldDefinition.Type)) + cStruct.Fields = append(cStruct.Fields, ctx.CFTypeToLLType(st.FieldDefinition.Type)) } else if st.FunctionDefinition != nil { var params []*ir.Param for _, p := range st.FunctionDefinition.Parameters { - params = append(params, ir.NewParam(p.Name, ctx.StringToType(p.Type))) + params = append(params, ir.NewParam(p.Name, ctx.CFTypeToLLType(p.Type))) } f := st.FunctionDefinition ms := "." + f.Name.Name if f.Name.Op { - ms = ".op." + strings.Trim(f.Name.String, "\"") + ms = ".op." + strings.Trim(f.Name.Name, "\"") } else if f.Name.Get { - ms = ".get." + strings.Trim(f.Name.String, "\"") + ms = ".get." + strings.Trim(f.Name.Name, "\"") } else if f.Name.Set { - ms = ".set." + strings.Trim(f.Name.String, "\"") + ms = ".set." + strings.Trim(f.Name.Name, "\"") } - fn := ctx.Module.NewFunc(s.Export.ClassDefinition.Name+ms, ctx.StringToType(f.ReturnType), params...) - fn.Sig.Variadic = false - fn.Sig.RetType = ctx.StringToType(f.ReturnType) + fn := ctx.Module.NewFunc(s.Export.ClassDefinition.Name+ms, ctx.CFMultiTypeToLLType(f.ReturnType), params...) + if st.FunctionDefinition.Variadic != "" { + fn.Sig.Variadic = true + } ctx.SymbolTable[s.Export.ClassDefinition.Name+ms] = fn } @@ -329,9 +333,9 @@ func (c *Compiler) ImportAs(path string, symbols map[string]string, ctx *Context } else if s.Export.External != nil { var params []*ir.Param for _, p := range s.Export.External.Parameters { - params = append(params, ir.NewParam(p.Name, ctx.StringToType(p.Type))) + params = append(params, ir.NewParam(p.Name, ctx.CFTypeToLLType(p.Type))) } - fn := c.Module.NewFunc(s.Export.External.Name, ctx.StringToType(s.Export.External.ReturnType), params...) + fn := c.Module.NewFunc(s.Export.External.Name, ctx.CFMultiTypeToLLType(s.Export.External.ReturnType), params...) ctx.SymbolTable[s.Export.External.Name] = fn } else { continue diff --git a/src/lib/compiler/expressions.go b/src/lib/compiler/expressions.go index 7d7cd7f..0596996 100644 --- a/src/lib/compiler/expressions.go +++ b/src/lib/compiler/expressions.go @@ -16,219 +16,377 @@ import ( ) func (ctx *Context) compileExpression(e *parser.Expression) (value.Value, error) { - left, err := ctx.compileComparison(e.Left) + cond, err := ctx.compileLogicalOr(e.Condition) if err != nil { return nil, err } + + if e.True != nil && e.False != nil { + if cond.Type() != types.I1 { + return nil, posError(e.Condition.Pos, "condition in ternary expression must be a boolean") + } + + trueVal, err := ctx.compileExpression(e.True) + if err != nil { + return nil, err + } + + falseVal, err := ctx.compileExpression(e.False) + if err != nil { + return nil, err + } + + if trueVal.Type() != falseVal.Type() { + return nil, posError(e.Pos, "true and false expressions in ternary expression must be the same type") + } + + return ctx.NewSelect(cond, trueVal, falseVal), nil + } + + return cond, nil +} + +func (ctx *Context) compileLogicalAnd(l *parser.LogicalAnd) (value.Value, error) { + left, err := ctx.compileBitwiseOr(l.Left) + if err != nil { + return nil, err + } + + if len(l.Right) != 0 && left.Type() != types.I1 { + return nil, posError(l.Left.Pos, "logical and operator requires boolean operands") + } + + for _, right := range l.Right { + rightVal, err := ctx.compileLogicalAnd(right) + if err != nil { + return nil, err + } + + if ptrType, ok := rightVal.Type().(*types.PointerType); ok && ptrType.ElemType == left.Type() { + rightVal = ctx.NewLoad(ptrType.ElemType, rightVal) + } + + if rightVal.Type() != types.I1 { + return nil, posError(right.Pos, "logical and operator requires boolean operands") + } + + left = ctx.NewAnd(left, rightVal) + } + + return left, nil +} + +func (ctx *Context) compileLogicalOr(l *parser.LogicalOr) (value.Value, error) { + left, err := ctx.compileLogicalAnd(l.Left) + if err != nil { + return nil, err + } + + if len(l.Right) != 0 && left.Type() != types.I1 { + return nil, posError(l.Left.Pos, "logical or operator requires boolean operands") + } + + for _, right := range l.Right { + rightVal, err := ctx.compileLogicalOr(right) + if err != nil { + return nil, err + } + + if ptrType, ok := rightVal.Type().(*types.PointerType); ok && ptrType.ElemType == left.Type() { + rightVal = ctx.NewLoad(ptrType.ElemType, rightVal) + } + + if rightVal.Type() != types.I1 { + return nil, posError(right.Pos, "logical or operator requires boolean operands") + } + + left = ctx.NewOr(left, rightVal) + } + + return left, nil +} + +func (ctx *Context) compileBitwiseAnd(b *parser.BitwiseAnd) (value.Value, error) { + left, err := ctx.compileEquality(b.Left) + if err != nil { + return nil, err + } + + if _, ok := left.Type().(*types.IntType); len(b.Right) != 0 && !ok { + return nil, posError(b.Left.Pos, "bitwise and operator requires integer operands") + } + + for _, right := range b.Right { + rightVal, err := ctx.compileBitwiseAnd(right) + if err != nil { + return nil, err + } + + if ptrType, ok := rightVal.Type().(*types.PointerType); ok && ptrType.ElemType == left.Type() { + rightVal = ctx.NewLoad(ptrType.ElemType, rightVal) + } + + if _, ok := rightVal.Type().(*types.IntType); !ok { + return nil, posError(right.Pos, "bitwise and operator requires integer operands") + } + + if left.Type() != rightVal.Type() { + return nil, posError(right.Pos, "operands must be the same type (%s != %s)", left.Type(), rightVal.Type()) + } + + left = ctx.NewAnd(left, rightVal) + } + + return left, nil +} + +func (ctx *Context) compileBitwiseXor(b *parser.BitwiseXor) (value.Value, error) { + left, err := ctx.compileBitwiseAnd(b.Left) + if err != nil { + return nil, err + } + + if _, ok := left.Type().(*types.IntType); len(b.Right) != 0 && !ok { + return nil, posError(b.Left.Pos, "bitwise xor operator requires integer operands") + } + + for _, right := range b.Right { + rightVal, err := ctx.compileBitwiseXor(right) + if err != nil { + return nil, err + } + + if ptrType, ok := rightVal.Type().(*types.PointerType); ok && ptrType.ElemType == left.Type() { + rightVal = ctx.NewLoad(ptrType.ElemType, rightVal) + } + + if _, ok := rightVal.Type().(*types.IntType); !ok { + return nil, posError(right.Pos, "bitwise xor operator requires integer operands") + } + + if left.Type() != rightVal.Type() { + return nil, posError(right.Pos, "operands must be the same type (%s != %s)", left.Type(), rightVal.Type()) + } + + left = ctx.NewXor(left, rightVal) + } + + return left, nil +} + +func (ctx *Context) compileBitwiseOr(b *parser.BitwiseOr) (value.Value, error) { + left, err := ctx.compileBitwiseXor(b.Left) + if err != nil { + return nil, err + } + + if _, ok := left.Type().(*types.IntType); len(b.Right) != 0 && !ok { + return nil, posError(b.Left.Pos, "bitwise or operator requires integer operands") + } + + for _, right := range b.Right { + rightVal, err := ctx.compileBitwiseOr(right) + if err != nil { + return nil, err + } + + if ptrType, ok := rightVal.Type().(*types.PointerType); ok && ptrType.ElemType == left.Type() { + rightVal = ctx.NewLoad(ptrType.ElemType, rightVal) + } + + if _, ok := rightVal.Type().(*types.IntType); !ok { + return nil, posError(right.Pos, "bitwise or operator requires integer operands") + } + + if left.Type() != rightVal.Type() { + return nil, posError(right.Pos, "operands must be the same type (%s != %s)", left.Type(), rightVal.Type()) + } + + left = ctx.NewOr(left, rightVal) + } + + return left, nil +} + +func (ctx *Context) compileEquality(e *parser.Equality) (value.Value, error) { + left, err := ctx.compileRelational(e.Left) + if err != nil { + return nil, err + } + for _, right := range e.Right { - ctx.RequestedType = left.Type() - rightVal, err := ctx.compileComparison(right.Expression) + rightVal, err := ctx.compileEquality(right) if err != nil { return nil, err } - if !left.Type().Equal(rightVal.Type()) { - targetType := left.Type() + if ptrType, ok := rightVal.Type().(*types.PointerType); ok && ptrType.ElemType == left.Type() { + rightVal = ctx.NewLoad(ptrType.ElemType, rightVal) + } - // Predefined bitcasts - switch targetType { - case types.Double: - if rightVal.Type().Equal(types.Float) { - rightVal = ctx.NewFPExt(rightVal, types.Double) - } - case types.Float: - if rightVal.Type().Equal(types.Double) { - rightVal = ctx.NewFPTrunc(rightVal, types.Float) - } - } + if left.Type() != rightVal.Type() { + return nil, posError(e.Left.Pos, "operands must be the same type") + } - if valType, ok := rightVal.Type().(*types.IntType); ok { - if target, ok := targetType.(*types.IntType); ok { - if valType.BitSize < target.BitSize { - // Extend if valType is smaller than targetType - rightVal = ctx.NewSExt(rightVal, targetType) - } else if valType.BitSize > target.BitSize { - // Truncate if valType is larger than targetType - rightVal = ctx.NewTrunc(rightVal, targetType) - } - } else if targetType == types.Float { - rightVal = ctx.NewSIToFP(rightVal, targetType) - } else if targetType == types.Double { - rightVal = ctx.NewSIToFP(rightVal, targetType) - } else if targetType == types.Half { - rightVal = ctx.NewSIToFP(rightVal, targetType) - } else if targetType == types.FP128 { - rightVal = ctx.NewSIToFP(rightVal, targetType) - } + switch right.Op { + case "==": + if types.IsFloat(left.Type()) { + left = ctx.NewFCmp(enum.FPredOEQ, left, rightVal) + } else { + left = ctx.NewICmp(enum.IPredEQ, left, rightVal) } - - // If the value is a struct type or a pointer to a struct type, try to find a conversion function - if structType, ok := rightVal.Type().(*types.StructType); ok { - method, ok := ctx.lookupFunction(structType.Name() + ".get." + targetType.Name()) - if ok { - // If a conversion function is found, call it and return the result - rightVal = ctx.NewCall(method, rightVal) - } - } else if ptrType, ok := rightVal.Type().(*types.PointerType); ok { - if structType, ok := ptrType.ElemType.(*types.StructType); ok { - method, ok := ctx.lookupFunction(structType.Name() + ".get." + targetType.Name()) - if ok { - // If a conversion function is found, call it and return the result - rightVal = ctx.NewCall(method, rightVal) - } - } + case "!=": + if types.IsFloat(left.Type()) { + left = ctx.NewFCmp(enum.FPredONE, left, rightVal) + } else { + left = ctx.NewICmp(enum.IPredNE, left, rightVal) } + default: + return nil, posError(right.Pos, "unknown equality operator: %s", right.Op) + } + } - if !rightVal.Type().Equal(targetType) { - return nil, posError(right.Pos, "Automated conversion from %s to %s failed.", rightVal.Type(), targetType) - } + return left, nil +} + +func (ctx *Context) compileRelational(r *parser.Relational) (value.Value, error) { + left, err := ctx.compileShift(r.Left) + if err != nil { + return nil, err + } + + if len(r.Right) != 0 && !isNumeric(left.Type()) { + return nil, posError(r.Left.Pos, "relational operator requires numeric operands") + } + + lrop := r.Op + for _, right := range r.Right { + rightVal, err := ctx.compileRelational(right) + if err != nil { + return nil, err } - switch leftType := left.(type) { - case *ir.InstLoad, *ir.InstCall, *ir.InstAlloca: - if structType, ok := leftType.Type().(*types.PointerType); ok { - if _, ok := structType.ElemType.(*types.StructType); ok { - // Check if the class has a method with the name "classname.op.operator" - methodName := fmt.Sprintf("%s.op.%s", structType.ElemType.Name(), right.Op) - if method, ok := ctx.lookupFunction(methodName); ok { - // Call the method and use its result as the result - left = ctx.NewCall(method, left, rightVal) - continue - } - } - } + if ptrType, ok := rightVal.Type().(*types.PointerType); ok && ptrType.ElemType == left.Type() { + rightVal = ctx.NewLoad(ptrType.ElemType, rightVal) } - switch rightVal.Type() { - case types.Float: - if !left.Type().Equal(types.Float) { - leftFloat := ctx.NewSIToFP(left, rightVal.Type()) - left = leftFloat - } - switch right.Op { - case "+": - left = ctx.NewFAdd(left, rightVal) - case "-": - left = ctx.NewFSub(left, rightVal) - case "&&": - left = ctx.NewAnd(left, rightVal) - case "||": - left = ctx.NewOr(left, rightVal) - default: - return nil, posError(right.Pos, "Unknown expression operator: %s", right.Op) + if !isNumeric(rightVal.Type()) { + return nil, posError(right.Pos, "relational operator requires numeric operands") + } + + if !left.Type().Equal(rightVal.Type()) { + return nil, posError(right.Pos, "operands must be the same type (%s != %s)", left.Type(), rightVal.Type()) + } + + switch lrop { + case "<=": + if types.IsFloat(left.Type()) { + left = ctx.NewFCmp(enum.FPredOLE, left, rightVal) + } else { + left = ctx.NewICmp(enum.IPredSLE, left, rightVal) } - case types.Double: - if !left.Type().Equal(types.Double) { - leftFloat := ctx.NewSIToFP(left, rightVal.Type()) - left = leftFloat + case ">=": + if types.IsFloat(left.Type()) { + left = ctx.NewFCmp(enum.FPredOGE, left, rightVal) + } else { + left = ctx.NewICmp(enum.IPredSGE, left, rightVal) } - - switch right.Op { - case "+": - left = ctx.NewFAdd(left, rightVal) - case "-": - left = ctx.NewFSub(left, rightVal) - case "&&": - left = ctx.NewAnd(left, rightVal) - case "||": - left = ctx.NewOr(left, rightVal) - default: - return nil, posError(right.Pos, "Unknown expression operator: %s", right.Op) + case "<": + if types.IsFloat(left.Type()) { + left = ctx.NewFCmp(enum.FPredOLT, left, rightVal) + } else { + left = ctx.NewICmp(enum.IPredSLT, left, rightVal) } - default: - switch right.Op { - case "+": - left = ctx.NewAdd(left, rightVal) - case "-": - left = ctx.NewSub(left, rightVal) - case "&&": - left = ctx.NewAnd(left, rightVal) - case "||": - left = ctx.NewOr(left, rightVal) - default: - return nil, posError(right.Pos, "Unknown expression operator: %s", right.Op) + case ">": + if types.IsFloat(left.Type()) { + left = ctx.NewFCmp(enum.FPredOGT, left, rightVal) + } else { + left = ctx.NewICmp(enum.IPredSGT, left, rightVal) } + default: + return nil, posError(right.Pos, "unknown relational operator: %s", lrop) } + lrop = right.Op } - ctx.RequestedType = nil + return left, nil } -func (ctx *Context) compileComparison(c *parser.Comparison) (value.Value, error) { - left, err := ctx.compileTerm(c.Left) +func (ctx *Context) compileShift(s *parser.Shift) (value.Value, error) { + left, err := ctx.compileAdditive(s.Left) if err != nil { return nil, err } - for _, right := range c.Right { - ctx.RequestedType = left.Type() - rightVal, err := ctx.compileTerm(right.Comparison) + + if _, ok := left.Type().(*types.IntType); len(s.Right) != 0 && !ok { + return nil, posError(s.Left.Pos, "shift operator requires integer operands") + } + + for _, right := range s.Right { + rightVal, err := ctx.compileShift(right) if err != nil { return nil, err } - if !left.Type().Equal(rightVal.Type()) { - targetType := left.Type() + if ptrType, ok := rightVal.Type().(*types.PointerType); ok && ptrType.ElemType == left.Type() { + rightVal = ctx.NewLoad(ptrType.ElemType, rightVal) + } - // Predefined bitcasts - switch targetType { - case types.Double: - if rightVal.Type().Equal(types.Float) { - rightVal = ctx.NewFPExt(rightVal, types.Double) - } - case types.Float: - if rightVal.Type().Equal(types.Double) { - rightVal = ctx.NewFPTrunc(rightVal, types.Float) - } - } + if _, ok := rightVal.Type().(*types.IntType); !ok { + return nil, posError(right.Pos, "shift operator requires integer operands") + } - if valType, ok := rightVal.Type().(*types.IntType); ok { - if target, ok := targetType.(*types.IntType); ok { - if valType.BitSize < target.BitSize { - // Extend if valType is smaller than targetType - rightVal = ctx.NewSExt(rightVal, targetType) - } else if valType.BitSize > target.BitSize { - // Truncate if valType is larger than targetType - rightVal = ctx.NewTrunc(rightVal, targetType) - } - } else if targetType == types.Float { - rightVal = ctx.NewSIToFP(rightVal, targetType) - } else if targetType == types.Double { - rightVal = ctx.NewSIToFP(rightVal, targetType) - } else if targetType == types.Half { - rightVal = ctx.NewSIToFP(rightVal, targetType) - } else if targetType == types.FP128 { - rightVal = ctx.NewSIToFP(rightVal, targetType) - } - } + if left.Type() != rightVal.Type() { + return nil, posError(right.Pos, "operands must be the same type (%s != %s)", left.Type(), rightVal.Type()) + } - // If the value is a struct type or a pointer to a struct type, try to find a conversion function - if structType, ok := rightVal.Type().(*types.StructType); ok { - method, ok := ctx.lookupFunction(structType.Name() + ".get." + targetType.Name()) - if ok { - // If a conversion function is found, call it and return the result - rightVal = ctx.NewCall(method, rightVal) - } - } else if ptrType, ok := rightVal.Type().(*types.PointerType); ok { - if structType, ok := ptrType.ElemType.(*types.StructType); ok { - method, ok := ctx.lookupFunction(structType.Name() + ".get." + targetType.Name()) - if ok { - // If a conversion function is found, call it and return the result - rightVal = ctx.NewCall(method, rightVal) - } - } - } + switch right.Op { + case "<<": + left = ctx.NewShl(left, rightVal) + case ">>", ">>>": + left = ctx.NewLShr(left, rightVal) + default: + return nil, posError(right.Pos, "unknown shift operator: %s", right.Op) + } + } - if !rightVal.Type().Equal(targetType) { - return nil, posError(right.Pos, "Automated conversion from %s to %s failed.", rightVal.Type(), targetType) - } + return left, nil +} + +func (ctx *Context) compileAdditive(a *parser.Additive) (value.Value, error) { + left, err := ctx.compileMultiplicative(a.Left) + if err != nil { + return nil, err + } + + if len(a.Right) != 0 && !isNumeric(left.Type()) { + return nil, posError(a.Left.Pos, "additive operator requires numeric operands") + } + + for _, right := range a.Right { + rightVal, err := ctx.compileAdditive(right) + if err != nil { + return nil, err + } + + if ptrType, ok := rightVal.Type().(*types.PointerType); ok && ptrType.ElemType == left.Type() { + rightVal = ctx.NewLoad(ptrType.ElemType, rightVal) + } + + if !isNumeric(rightVal.Type()) { + return nil, posError(right.Pos, "additive operator requires numeric operands") + } + + if left.Type() != rightVal.Type() { + return nil, posError(right.Pos, "operands must be the same type (%s != %s)", left.Type(), rightVal.Type()) } switch leftType := left.(type) { case *ir.InstLoad, *ir.InstCall, *ir.InstAlloca: - if structType, ok := leftType.Type().(*types.PointerType); ok { - if _, ok := structType.ElemType.(*types.StructType); ok { + if ptrType, ok := leftType.Type().(*types.PointerType); ok { + if structType, ok := ptrType.ElemType.(*types.StructType); ok { // Check if the class has a method with the name "classname.op.operator" - methodName := fmt.Sprintf("%s.op.%s", structType.ElemType.Name(), right.Op) + methodName := fmt.Sprintf("%s.op.%s", structType.Name(), a.Op) if method, ok := ctx.lookupFunction(methodName); ok { // Call the method and use its result as the result left = ctx.NewCall(method, left, rightVal) @@ -238,102 +396,61 @@ func (ctx *Context) compileComparison(c *parser.Comparison) (value.Value, error) } } - switch right.Op { - case "==": - left = ctx.NewICmp(enum.IPredEQ, left, rightVal) - case "!=": - left = ctx.NewICmp(enum.IPredNE, left, rightVal) - case ">": - left = ctx.NewICmp(enum.IPredSGT, left, rightVal) - case "<": - left = ctx.NewICmp(enum.IPredSLT, left, rightVal) - case ">=": - left = ctx.NewICmp(enum.IPredSGE, left, rightVal) - case "<=": - left = ctx.NewICmp(enum.IPredSLE, left, rightVal) + switch a.Op { + case "+": + if types.IsFloat(left.Type()) { + left = ctx.NewFAdd(left, rightVal) + } else { + left = ctx.NewAdd(left, rightVal) + } + case "-": + if types.IsFloat(left.Type()) { + left = ctx.NewFSub(left, rightVal) + } else { + left = ctx.NewSub(left, rightVal) + } default: - return nil, posError(right.Pos, "Unknown comparison operator: %s", right.Op) + return nil, posError(right.Pos, "unknown additive operator: %s", a.Op) } } - ctx.RequestedType = nil + return left, nil } -func (ctx *Context) compileTerm(t *parser.Term) (value.Value, error) { - left, err := ctx.compileFactor(t.Left) +func (ctx *Context) compileMultiplicative(m *parser.Multiplicative) (value.Value, error) { + left, err := ctx.compileLogicalNot(m.Left) if err != nil { return nil, err } - for _, right := range t.Right { - ctx.RequestedType = left.Type() - rightVal, err := ctx.compileFactor(right.Term) + + if len(m.Right) != 0 && !isNumeric(left.Type()) { + return nil, posError(m.Left.Pos, "multiplicative operator requires numeric operands") + } + + for _, right := range m.Right { + rightVal, err := ctx.compileMultiplicative(right) if err != nil { return nil, err } - if !left.Type().Equal(rightVal.Type()) { - targetType := left.Type() - - // Predefined bitcasts - switch targetType { - case types.Double: - if rightVal.Type().Equal(types.Float) { - rightVal = ctx.NewFPExt(rightVal, types.Double) - } - case types.Float: - if rightVal.Type().Equal(types.Double) { - rightVal = ctx.NewFPTrunc(rightVal, types.Float) - } - } - - if valType, ok := rightVal.Type().(*types.IntType); ok { - if target, ok := targetType.(*types.IntType); ok { - if valType.BitSize < target.BitSize { - // Extend if valType is smaller than targetType - rightVal = ctx.NewSExt(rightVal, targetType) - } else if valType.BitSize > target.BitSize { - // Truncate if valType is larger than targetType - rightVal = ctx.NewTrunc(rightVal, targetType) - } - } else if targetType == types.Float { - rightVal = ctx.NewSIToFP(rightVal, targetType) - } else if targetType == types.Double { - rightVal = ctx.NewSIToFP(rightVal, targetType) - } else if targetType == types.Half { - rightVal = ctx.NewSIToFP(rightVal, targetType) - } else if targetType == types.FP128 { - rightVal = ctx.NewSIToFP(rightVal, targetType) - } - } + if ptrType, ok := rightVal.Type().(*types.PointerType); ok && ptrType.ElemType == left.Type() { + rightVal = ctx.NewLoad(ptrType.ElemType, rightVal) + } - // If the value is a struct type or a pointer to a struct type, try to find a conversion function - if structType, ok := rightVal.Type().(*types.StructType); ok { - method, ok := ctx.lookupFunction(structType.Name() + ".get." + targetType.Name()) - if ok { - // If a conversion function is found, call it and return the result - rightVal = ctx.NewCall(method, rightVal) - } - } else if ptrType, ok := rightVal.Type().(*types.PointerType); ok { - if structType, ok := ptrType.ElemType.(*types.StructType); ok { - method, ok := ctx.lookupFunction(structType.Name() + ".get." + targetType.Name()) - if ok { - // If a conversion function is found, call it and return the result - rightVal = ctx.NewCall(method, rightVal) - } - } - } + if !isNumeric(rightVal.Type()) { + return nil, posError(right.Pos, "multiplicative operator requires numeric operands") + } - if !rightVal.Type().Equal(targetType) { - return nil, posError(right.Pos, "Automated conversion from %s to %s failed.", rightVal.Type(), targetType) - } + if left.Type() != rightVal.Type() { + return nil, posError(right.Pos, "operands must be the same type (%s != %s)", left.Type(), rightVal.Type()) } switch leftType := left.(type) { case *ir.InstLoad, *ir.InstCall, *ir.InstAlloca: - if structType, ok := leftType.Type().(*types.PointerType); ok { - if _, ok := structType.ElemType.(*types.StructType); ok { + if ptrType, ok := leftType.Type().(*types.PointerType); ok { + if structType, ok := ptrType.ElemType.(*types.StructType); ok { // Check if the class has a method with the name "classname.op.operator" - methodName := fmt.Sprintf("%s.op.%s", structType.ElemType.Name(), right.Op) + methodName := fmt.Sprintf("%s.op.%s", structType.Name(), right.Op) if method, ok := ctx.lookupFunction(methodName); ok { // Call the method and use its result as the result left = ctx.NewCall(method, left, rightVal) @@ -363,10 +480,100 @@ func (ctx *Context) compileTerm(t *parser.Term) (value.Value, error) { left = ctx.NewSRem(left, rightVal) } default: - return nil, posError(right.Pos, "Unknown term operator: %s", right.Op) + return nil, posError(right.Pos, "unknown multiplicative operator: %s", right.Op) + } + } + + return left, nil +} + +func (ctx *Context) compileLogicalNot(l *parser.LogicalNot) (value.Value, error) { + right, err := ctx.compileBitwiseNot(l.Right) + if err != nil { + return nil, err + } + + if l.Op != "" { + if right.Type() != types.I1 { + return nil, posError(l.Right.Pos, "logical not operator requires a boolean operand") + } + right = ctx.NewXor(right, constant.NewInt(types.I1, 1)) + } + + return right, nil +} + +func (ctx *Context) compileBitwiseNot(b *parser.BitwiseNot) (value.Value, error) { + right, err := ctx.compilePrefixAdditive(b.Right) + if err != nil { + return nil, err + } + + if b.Op != "" { + intType, ok := right.Type().(*types.IntType) + if !ok { + return nil, posError(b.Right.Pos, "bitwise not operator requires an integer operand") + } + + mask := constant.NewInt(intType, -1) + right = ctx.NewXor(right, mask) + } + + return right, nil +} + +func (ctx *Context) compilePrefixAdditive(p *parser.PrefixAdditive) (value.Value, error) { + right, err := ctx.compilePostfixAdditive(p.Right) + if err != nil { + return nil, err + } + + if p.Op != "" { + if _, ok := right.Type().(*types.FloatType); ok { + if p.Op == "++" { + return ctx.NewFAdd(right, constant.NewFloat(types.Float, 1)), nil + } else { + return ctx.NewFSub(right, constant.NewFloat(types.Float, 1)), nil + } + } else { + if p.Op == "++" { + return ctx.NewAdd(right, constant.NewInt(types.I8, 1)), nil + } else { + return ctx.NewSub(right, constant.NewInt(types.I8, 1)), nil + } } } - ctx.RequestedType = nil + + return right, nil +} + +func (ctx *Context) compilePostfixAdditive(p *parser.PostfixAdditive) (value.Value, error) { + left, err := ctx.compileFactor(p.Left) + if err != nil { + return nil, err + } + + if p.Op != "" { + original := left + if _, ok := left.Type().(*types.FloatType); ok { + if p.Op == "++" { + left = ctx.NewFAdd(left, constant.NewFloat(types.Float, 1)) + } else { + left = ctx.NewFSub(left, constant.NewFloat(types.Float, 1)) + } + } else { + if p.Op == "++" { + left = ctx.NewAdd(left, constant.NewInt(types.I8, 1)) + } else { + left = ctx.NewSub(left, constant.NewInt(types.I8, 1)) + } + } + if _, ok := original.Type().(*types.PointerType); ok { + ctx.NewStore(left, original) + } + return original, nil + } + return left, nil } @@ -412,11 +619,11 @@ func (ctx *Context) compileBitCast(bc *parser.BitCast) (value.Value, error) { return nil, err } - if bc.Type == "" { + if bc.Type == nil { return val, nil } - targetType := ctx.StringToType(bc.Type) + targetType := ctx.CFTypeToLLType(bc.Type) // If the value is already of the target type, just return it if val.Type().Equal(targetType) { @@ -455,7 +662,7 @@ func (ctx *Context) compileBitCast(bc *parser.BitCast) (value.Value, error) { // If the value is a struct type or a pointer to a struct type, try to find a conversion function if structType, ok := val.Type().(*types.StructType); ok { - method, ok := ctx.lookupFunction(structType.Name() + ".get." + bc.Type) + method, ok := ctx.lookupFunction(structType.Name() + ".get." + ctx.CFTypeToLLType(bc.Type).Name()) if ok { // If a conversion function is found, call it and return the result result := ctx.NewCall(method, val) @@ -463,7 +670,7 @@ func (ctx *Context) compileBitCast(bc *parser.BitCast) (value.Value, error) { } } else if ptrType, ok := val.Type().(*types.PointerType); ok { if structType, ok := ptrType.ElemType.(*types.StructType); ok { - method, ok := ctx.lookupFunction(structType.Name() + ".get." + bc.Type) + method, ok := ctx.lookupFunction(structType.Name() + ".get." + ctx.CFTypeToLLType(bc.Type).Name()) if ok { // If a conversion function is found, call it and return the result result := ctx.NewCall(method, val) @@ -478,7 +685,7 @@ func (ctx *Context) compileBitCast(bc *parser.BitCast) (value.Value, error) { return bitcast, nil } - return nil, posError(bc.Pos, "Cannot convert %s to %s", val.Type().Name(), bc.Type) + return nil, posError(bc.Pos, "Cannot convert %s to %s", val.Type().Name(), ctx.CFTypeToLLType(bc.Type).Name()) } func (ctx *Context) compileClassInitializer(ci *parser.ClassInitializer) (value.Value, error) { @@ -615,25 +822,6 @@ func (ctx *Context) compileValue(v *parser.Value) (value.Value, error) { strGlobal.Immutable = true strGlobal.Linkage = enum.LinkagePrivate return strGlobal, nil - } else if v.Duration != nil { - var factor float64 - switch v.Duration.Unit { - case "h": - factor = 3600 - case "m": - factor = 60 - case "s": - factor = 1 - case "ms": - factor = 0.001 - case "us": - factor = 0.000001 - case "ns": - factor = 0.000000001 - default: - return nil, posError(v.Pos, "Unknown duration unit: %s", v.Duration.Unit) - } - return constant.NewFloat(types.Double, v.Duration.Number*factor), nil } else if v.Null { return constant.NewNull(types.I8Ptr), nil } else { @@ -656,9 +844,18 @@ func (ctx *Context) compileIdentifier(i *parser.Identifier, returnTopLevelStruct } ctx.RequestedType = nil - // Run GetElementPtr on the loaded value - v := ctx.NewGetElementPtr(val.Type.(*types.PointerType).ElemType, val.Value, gepExpr) - return v, v.Type(), nil + var elementType types.Type + switch t := val.Type.(type) { + case *types.PointerType: + elementType = t.ElemType + case *types.ArrayType: + elementType = t.ElemType + default: + return nil, nil, posError(i.GEP.Pos, "unsupported type for GetElementPtr: %s", t) + } + + v := ctx.NewGetElementPtr(elementType, val.Value, gepExpr) + return v, elementType, nil } // Handle referencing for j := 0; j < len(i.Ref); j++ { @@ -770,7 +967,7 @@ func (ctx *Context) compileSubIdentifier(f *Variable, sub *parser.Identifier) (F return nil, nil, false, posError(sub.Pos, "Field %s not found in struct %s", sub.Name, elemtypename) } - fieldPtr := ctx.NewGetElementPtr(ctx.StringToType(field.Type), f.Value, constant.NewInt(types.I32, int64(nfield))) + fieldPtr := ctx.NewGetElementPtr(ctx.CFTypeToLLType(field.Type), f.Value, constant.NewInt(types.I32, int64(nfield))) if sub.GEP != nil { ctx.RequestedType = types.I32 gepExpr, err := ctx.compileExpression(sub.GEP) diff --git a/src/lib/compiler/header.go b/src/lib/compiler/header.go index a14592d..30575d7 100644 --- a/src/lib/compiler/header.go +++ b/src/lib/compiler/header.go @@ -137,7 +137,7 @@ func WriteHeader(f *os.File, comp *Compiler) error { if !field.Private { continue } - _, err = f.WriteString(convertCffTypeToCType(comp.Context.StringToType(field.Type)) + " " + field.Name + ";\n") + _, err = f.WriteString(convertCffTypeToCType(comp.Context.CFTypeToLLType(field.Type)) + " " + field.Name + ";\n") if err != nil { return err } @@ -152,7 +152,7 @@ func WriteHeader(f *os.File, comp *Compiler) error { if field.Private { continue } - _, err = f.WriteString(convertCffTypeToCType(comp.Context.StringToType(field.Type)) + " " + field.Name + ";\n") + _, err = f.WriteString(convertCffTypeToCType(comp.Context.CFTypeToLLType(field.Type)) + " " + field.Name + ";\n") if err != nil { return err } diff --git a/src/lib/compiler/statements.go b/src/lib/compiler/statements.go index 110be7a..4e95dbb 100644 --- a/src/lib/compiler/statements.go +++ b/src/lib/compiler/statements.go @@ -6,6 +6,7 @@ import ( "github.com/fatih/color" "github.com/llir/llvm/ir" "github.com/llir/llvm/ir/constant" + "github.com/llir/llvm/ir/enum" "github.com/llir/llvm/ir/types" "github.com/llir/llvm/ir/value" "github.com/urfave/cli/v2" @@ -17,8 +18,7 @@ func (ctx *Context) compileStatement(s *parser.Statement) error { _, _, _, err := ctx.compileVariableDefinition(s.VariableDefinition) return err } else if s.Assignment != nil { - _, _, err := ctx.compileAssignment(s.Assignment) - return err + return ctx.compileAssignment(s.Assignment) } else if s.FunctionDefinition != nil { _, _, _, err := ctx.compileFunctionDefinition(s.FunctionDefinition) return err @@ -31,6 +31,8 @@ func (ctx *Context) compileStatement(s *parser.Statement) error { return ctx.compileFor(s.For) } else if s.While != nil { return ctx.compileWhile(s.While) + } else if s.Until != nil { + return ctx.compileUntil(s.Until) } else if s.Return != nil { return ctx.compileReturn(s.Return) } else if s.Break != nil { @@ -45,43 +47,37 @@ func (ctx *Context) compileStatement(s *parser.Statement) error { } else if s.External != nil { ctx.compileExternalFunction(s.External) } else if s.Import != nil { - //return ctx.Compiler.ImportAll(s.Import.Package, ctx) + return ctx.Compiler.ImportAll(s.Import.Package, ctx) } else if s.FromImport != nil { - /* - symbols := map[string]string{strings.Trim(s.FromImport.Symbol, "\""): strings.Trim(s.FromImport.Symbol, "\"")} - ctx.Compiler.ImportAs(s.FromImport.Package, symbols, ctx) - */ + symbols := map[string]string{strings.Trim(s.FromImport.Symbol, "\""): strings.Trim(s.FromImport.Symbol, "\"")} + ctx.Compiler.ImportAs(s.FromImport.Package, symbols, ctx) } else if s.FromImportMultiple != nil { - /* - symbols := map[string]string{} - for _, symbol := range s.FromImportMultiple.Symbols { - if symbol.Alias == "" { - symbol.Alias = symbol.Name - } - symbols[strings.Trim(symbol.Name, "\"")] = strings.Trim(symbol.Alias, "\"") + symbols := map[string]string{} + for _, symbol := range s.FromImportMultiple.Symbols { + if symbol.Alias == "" { + symbol.Alias = symbol.Name } - ctx.Compiler.ImportAs(s.FromImportMultiple.Package, symbols, ctx) - */ + symbols[strings.Trim(symbol.Name, "\"")] = strings.Trim(symbol.Alias, "\"") + } + ctx.Compiler.ImportAs(s.FromImportMultiple.Package, symbols, ctx) } else if s.Export != nil { return ctx.compileStatement(s.Export) } else if s.Comment != nil { return nil - } else { - return posError(s.Pos, "Unknown statement") } return nil } func (ctx *Context) compileExternalFunction(v *parser.ExternalFunctionDefinition) { var retType types.Type - if v.ReturnType == "" { + if len(v.ReturnType) == 0 { retType = types.Void } else { - retType = ctx.StringToType(v.ReturnType) + retType = ctx.CFMultiTypeToLLType(v.ReturnType) } var args []*ir.Param for _, arg := range v.Parameters { - args = append(args, ir.NewParam(arg.Name, ctx.StringToType(arg.Type))) + args = append(args, ir.NewParam(arg.Name, ctx.CFTypeToLLType(arg.Type))) } v.Name = strings.Trim(v.Name, "\"") @@ -92,7 +88,27 @@ func (ctx *Context) compileExternalFunction(v *parser.ExternalFunctionDefinition func (ctx *Context) compileVariableDefinition(v *parser.VariableDefinition) (Name string, Type types.Type, Value value.Value, Err error) { // If there is no assignment, create an uninitialized variable - valType := ctx.StringToType(v.Type) + valType := ctx.CFTypeToLLType(v.Type) + + if v.Constant == "const" { + if v.Assignment == nil { + return "", nil, nil, posError(v.Pos, "Constant definition must have assignment") + } + + cVal, err := ctx.compileExpression(v.Assignment) + if err != nil { + return "", nil, nil, err + } + + ctx.vars[v.Name] = &Variable{ + Name: v.Name, + Type: valType, + Value: cVal, + } + + return v.Name, valType, cVal, nil + } + if v.Assignment == nil { alloc := ctx.NewAlloca(valType) ctx.NewStore(constant.NewZeroInitializer(valType), alloc) @@ -144,60 +160,168 @@ func (ctx *Context) compileVariableDefinition(v *parser.VariableDefinition) (Nam return v.Name, alloc.Type(), alloc, nil } -func (ctx *Context) compileAssignment(a *parser.Assignment) (Name string, Value value.Value, Err error) { - // Compile the identifier to get the variable - variable, vType, err := ctx.compileIdentifier(a.Left, false) - if err != nil { - return "", nil, err - } - - if ptr, ok := variable.(*ir.InstGetElementPtr); ok { - /* - fmt.Println("This now") - fmt.Println("vType: ", vType) - fmt.Println("ElemType: ", ptr.ElemType) - */ - ctx.RequestedType = ptr.ElemType - } else if ptr, ok := vType.(*types.PointerType); ok { - ctx.RequestedType = ptr.ElemType - } else { - ctx.RequestedType = vType +func (ctx *Context) compileAssignment(a *parser.Assignment) (Err error) { + type Ident struct { + Value value.Value + Type types.Type } + var idents = make([]Ident, len(a.Idents)) + + for index, ident := range a.Idents { + i, t, err := ctx.compileIdentifier(ident, false) + if err != nil { + return err + } + + if a.Op != "=" && !isNumeric(t) { + return posError(ident.Pos, "Numeric operator used on non-numeric identifier %s", ident.Name) + } + + idents[index] = Ident{Value: i, Type: t} + } + + ctx.RequestedType = idents[0].Type val, err := ctx.compileExpression(a.Right) if err != nil { - return "", nil, err + return err } ctx.RequestedType = nil - ptr, ok := variable.(*ir.InstGetElementPtr) - if !ok { - aptr, ok := variable.(*ir.InstAlloca) - if !ok { - ctx.vars[a.Left.Name] = &Variable{ - Name: a.Left.Name, - Type: vType, - Value: val, + if a.Op != "=" { + if !isNumeric(val.Type()) { + return posError(a.Right.Pos, "Numeric operator used on non-numeric value") + } + + for i, ident := range idents { + _, isFloat := ident.Value.Type().(*types.FloatType) + var v value.Value + switch a.Op { + case "+=": + if isFloat { + v = ctx.NewFAdd(ident.Value, val) + } else { + v = ctx.NewAdd(ident.Value, val) + } + case "-=": + if isFloat { + v = ctx.NewFSub(ident.Value, val) + } else { + v = ctx.NewSub(ident.Value, val) + } + case "*=": + if isFloat { + v = ctx.NewFMul(ident.Value, val) + } else { + v = ctx.NewMul(ident.Value, val) + } + case "/=": + if isFloat { + v = ctx.NewFDiv(ident.Value, val) + } else { + v = ctx.NewSDiv(ident.Value, val) + } + case "%=": + if isFloat { + return posError(a.Pos, "Modulus operator not allowed on float") + } + v = ctx.NewSRem(ident.Value, val) + case "&=": + v = ctx.NewAnd(ident.Value, val) + case "|=": + v = ctx.NewOr(ident.Value, val) + case "^=": + v = ctx.NewXor(ident.Value, val) + case "<<=": + v = ctx.NewShl(ident.Value, val) + case ">>=": + v = ctx.NewLShr(ident.Value, val) + case ">>>=": + v = ctx.NewAShr(ident.Value, val) + case "??=": + isNull := ctx.NewICmp(enum.IPredEQ, ident.Value, constant.NewNull(ident.Value.Type().(*types.PointerType))) + v = ctx.NewSelect(isNull, val, ident.Value) + } + + ptr, ok := ident.Value.(*ir.InstGetElementPtr) + if !ok { + aptr, ok := ident.Value.(*ir.InstAlloca) + if !ok { + ctx.vars[a.Idents[i].Name] = &Variable{ + Name: a.Idents[i].Name, + Type: ident.Type, + Value: v, + } + } else { + ctx.NewStore(v, aptr) + } + } else { + ctx.NewStore(v, ptr) } - } else { - ctx.NewStore(val, aptr) } } else { - ctx.NewStore(val, ptr) + if len(idents) == 1 { + ptr, ok := idents[0].Value.(*ir.InstGetElementPtr) + if !ok { + aptr, ok := idents[0].Value.(*ir.InstAlloca) + if !ok { + ctx.vars[a.Idents[0].Name] = &Variable{ + Name: a.Idents[0].Name, + Type: idents[0].Type, + Value: val, + } + } else { + ctx.NewStore(val, aptr) + } + } else { + ctx.NewStore(val, ptr) + } + } else { + if _, ok := val.Type().(*types.StructType); !ok { + return posError(a.Right.Pos, "Cannot assign non-struct value to multiple variables") + } + + if len(val.Type().(*types.StructType).Fields) != len(idents) { + return posError(a.Right.Pos, "Unable to unpack %d values into %d variables", len(val.Type().(*types.StructType).Fields), len(idents)) + } + + for i, ident := range idents { + ptr, ok := ident.Value.(*ir.InstGetElementPtr) + if !ok { + aptr, ok := ident.Value.(*ir.InstAlloca) + if !ok { + ctx.vars[a.Idents[i].Name] = &Variable{ + Name: a.Idents[i].Name, + Type: ident.Type, + Value: ctx.NewExtractValue(val, uint64(i)), + } + } else { + ctx.NewStore(ctx.NewExtractValue(val, uint64(i)), aptr) + } + } else { + ctx.NewStore(ctx.NewExtractValue(val, uint64(i)), ptr) + } + } + } } - return a.Left.Name, val, nil + return nil } func (ctx *Context) compileFunctionDefinition(f *parser.FunctionDefinition) (Name string, ReturnType types.Type, Args []*ir.Param, err error) { var params []*ir.Param for _, arg := range f.Parameters { - params = append(params, ir.NewParam(arg.Name, ctx.StringToType(arg.Type))) + params = append(params, ir.NewParam(arg.Name, ctx.CFTypeToLLType(arg.Type))) + } + if f.Variadic != "" { + params = append(params, ir.NewParam(f.Variadic, types.I8Ptr)) } - retType := ctx.StringToType(f.ReturnType) + retType := ctx.CFMultiTypeToLLType(f.ReturnType) fn := ctx.Module.NewFunc(f.Name.Name, retType, params...) - fn.Sig.Variadic = f.Variadic + if f.Variadic != "" { + fn.Sig.Variadic = true + } block := fn.NewBlock("") nctx := NewContext(block, ctx.Compiler) ctx.SymbolTable[f.Name.Name] = fn @@ -226,7 +350,7 @@ func (ctx *Context) compileClassDefinition(c *parser.ClassDefinition) (Name stri ctx.Module.NewTypeDef(c.Name, classType) for _, s := range c.Body { if s.FieldDefinition != nil { - classType.Fields = append(classType.Fields, ctx.StringToType(s.FieldDefinition.Type)) + classType.Fields = append(classType.Fields, ctx.CFTypeToLLType(s.FieldDefinition.Type)) ctx.Compiler.StructFields[c.Name] = append(ctx.Compiler.StructFields[c.Name], s.FieldDefinition) } else if s.FunctionDefinition != nil { err := ctx.compileClassMethodDefinition(s.FunctionDefinition, c.Name, classType) @@ -243,10 +367,13 @@ func (ctx *Context) compileClassMethodDefinition(f *parser.FunctionDefinition, c var params []*ir.Param params = append(params, ir.NewParam("this", types.NewPointer(ctype))) for _, arg := range f.Parameters { - params = append(params, ir.NewParam(arg.Name, ctx.StringToType(arg.Type))) + params = append(params, ir.NewParam(arg.Name, ctx.CFTypeToLLType(arg.Type))) + } + if f.Variadic != "" { + params = append(params, ir.NewParam(f.Variadic, types.I8Ptr)) } - trimmed := strings.Trim(f.Name.String, "\"") + trimmed := strings.Trim(f.Name.Name, "\"") ms := "." + f.Name.Name if f.Name.Op { ms = ".op." + trimmed @@ -256,10 +383,12 @@ func (ctx *Context) compileClassMethodDefinition(f *parser.FunctionDefinition, c ms = ".set." + trimmed } - retType := ctx.StringToType(f.ReturnType) + retType := ctx.CFMultiTypeToLLType(f.ReturnType) fn := ctx.Module.NewFunc(cname+ms, retType, params...) - fn.Sig.Variadic = false + if f.Variadic != "" { + fn.Sig.Variadic = true + } block := fn.NewBlock("") nctx := NewContext(block, ctx.Compiler) ctx.SymbolTable[cname+ms] = fn @@ -429,15 +558,69 @@ func (ctx *Context) compileWhile(w *parser.While) error { return nil } +func (ctx *Context) compileUntil(u *parser.Until) error { + cond, err := ctx.compileExpression(u.Condition) + if err != nil { + return err + } + + loopB := ctx.Block.Parent.NewBlock("") + leaveB := ctx.Block.Parent.NewBlock("") + loopCtx := ctx.NewContext(loopB) + + ctx.NewCondBr(cond, leaveB, loopB) + loopCtx.fc.Leave = leaveB + loopCtx.fc.Continue = loopB + + for _, stmt := range u.Body { + err := loopCtx.compileStatement(stmt) + if err != nil { + return err + } + } + + cond, err = loopCtx.compileExpression(u.Condition) + if err != nil { + return err + } + loopCtx.NewCondBr(cond, leaveB, loopB) + ctx.Block = leaveB + + return nil +} + func (ctx *Context) compileReturn(r *parser.Return) error { - if r.Expression != nil { + if len(r.Expressions) == 1 { ctx.RequestedType = ctx.Block.Parent.Sig.RetType - val, err := ctx.compileExpression(r.Expression) + val, err := ctx.compileExpression(r.Expressions[0]) if err != nil { return posError(r.Pos, "Error compiling return expression: %s", err.Error()) } ctx.RequestedType = nil ctx.NewRet(val) + } else if len(r.Expressions) > 1 { + if _, ok := ctx.Block.Parent.Sig.RetType.(*types.StructType); !ok { + return posError(r.Pos, "Cannot return multiple values from a non-struct function") + } + + var vals []constant.Constant + for i, expr := range r.Expressions { + ctx.RequestedType = ctx.Block.Parent.Sig.RetType.(*types.StructType).Fields[i] + val, err := ctx.compileExpression(expr) + ctx.RequestedType = nil + if err != nil { + return posError(r.Pos, "Error compiling return expression: %s", err.Error()) + } + + constVal, ok := val.(constant.Constant) + if !ok { + return posError(r.Pos, "Return expression did not evaluate to a constant") + } + + vals = append(vals, constVal) + } + + ctx.NewRet(constant.NewStruct(ctx.Block.Parent.Sig.RetType.(*types.StructType), vals...)) } else { ctx.NewRet(nil) } diff --git a/src/lib/compiler/utils.go b/src/lib/compiler/utils.go index aa7070b..5a9766a 100644 --- a/src/lib/compiler/utils.go +++ b/src/lib/compiler/utils.go @@ -7,14 +7,103 @@ import ( "github.com/alecthomas/participle/v2/lexer" "github.com/fatih/color" + "github.com/llir/llvm/ir/constant" "github.com/llir/llvm/ir/types" "github.com/urfave/cli/v2" + "github.com/vyPal/CaffeineC/lib/parser" ) func posError(pos lexer.Position, message string, args ...interface{}) error { return cli.Exit(color.RedString("%s at %s:%d:%d", fmt.Sprintf(message, args...), pos.Filename, pos.Line, pos.Column), 1) } +func (ctx *Context) CFTypeToLLType(t *parser.Type) types.Type { + pointerCount := strings.Count(t.Ptr, "*") + var typ types.Type + + if t.Inner != nil { + typ = ctx.CFTypeToLLType(t.Inner) + } else { + if strings.HasPrefix(t.Name, "i") || strings.HasPrefix(t.Name, "u") { + size, _ := strconv.Atoi(t.Name[1:]) + typ = types.NewInt(uint64(size)) + } else { + switch t.Name { + case "void", "": + typ = types.Void + case "f16": + typ = types.Half + case "f32": + typ = types.Float + case "f64": + typ = types.Double + case "f128": + typ = types.FP128 + default: + for _, ty := range ctx.Module.TypeDefs { + if ty.Name() == t.Name { + typ = ty + break + } + } + } + } + + if typ == nil { + panic("Unknown type: " + t.Name) + } + } + + // If the type is a pointer, wrap it in the appropriate number of pointer types + for i := 0; i < pointerCount; i++ { + typ = types.NewPointer(typ) + } + + if t.Array != nil { + array, err := ctx.compileExpression(t.Array) + if err != nil { + panic(err) + } + + arraySize, ok := array.(*constant.Int) + if !ok { + panic("array size is not a constant integer") + } + + length := uint64(arraySize.X.Int64()) + + typ = types.NewArray(length, typ) + } + + return typ +} + +func (ctx *Context) CFMultiTypeToLLType(typeArr []*parser.Type) types.Type { + if len(typeArr) == 1 { + return ctx.CFTypeToLLType(typeArr[0]) + } else if len(typeArr) == 0 { + return types.Void + } + + var typs []types.Type + for _, t := range typeArr { + typs = append(typs, ctx.CFTypeToLLType(t)) + } + + return types.NewStruct(typs...) +} + +func isNumeric(t types.Type) bool { + switch t := t.(type) { + case *types.IntType, *types.FloatType: + return true + case *types.PointerType: + return isNumeric(t.ElemType) + default: + return false + } +} + func (ctx *Context) StringToType(name string) types.Type { pointerCount := strings.Count(name, "*") name = strings.TrimLeft(name, "*") diff --git a/src/lib/parser/grammar.go b/src/lib/parser/grammar.go index 551c66c..185b229 100644 --- a/src/lib/parser/grammar.go +++ b/src/lib/parser/grammar.go @@ -1,7 +1,7 @@ package parser import ( - "strconv" + "errors" "github.com/alecthomas/participle/v2/lexer" ) @@ -9,34 +9,27 @@ import ( type Bool bool func (b *Bool) Capture(values []string) error { - *b = values[0] == "true" - return nil -} - -type Duration struct { - Number float64 - Unit string -} - -func (d *Duration) Capture(values []string) error { - num, err := strconv.ParseFloat(values[0], 64) - if err != nil { - return err + switch values[0] { + case "true", "True": + *b = true + return nil + case "false", "False": + *b = false + return nil + default: + return errors.New(values[0] + " is not a valid boolean value") } - d.Number = num - d.Unit = values[1] - return nil } type Value struct { - Pos lexer.Position - Float *float64 `parser:" @('-'? Float)"` - Duration *Duration `parser:"| @Int @('h' | 'm' | 's' | 'ms' | 'us' | 'ns')"` - Int *int64 `parser:"| @('-'? Int)"` - HexInt *string `parser:"| @('0x' (Int | 'a' | 'b' | 'c' | 'd' | 'e' | 'f' | 'A' | 'B' | 'C' | 'D' | 'E' | 'F')+)"` - Bool *Bool `parser:"| @('true' | 'false')"` - String *string `parser:"| @String"` - Null bool `parser:"| @'null'"` + Pos lexer.Position + Array []*Expression `parser:"'[' ( @@ ( ',' @@ )* )? ']'"` + Float *float64 `parser:" @('-'? Float)"` + Int *int64 `parser:"| @('-'? Int)"` + HexInt *string `parser:"| @('-'? '0x' (Int | 'a' | 'b' | 'c' | 'd' | 'e' | 'f' | 'A' | 'B' | 'C' | 'D' | 'E' | 'F')+)"` + Bool *Bool `parser:"| @('true' | 'True' | 'false' | 'False')"` + String *string `parser:"| @String"` + Null bool `parser:"| @'null'"` } type Identifier struct { @@ -67,6 +60,7 @@ type FunctionCall struct { type Factor struct { Pos lexer.Position + Unpack bool `parser:"@'...'?"` Value *Value `parser:" @@"` FunctionCall *FunctionCall `parser:"| (?= ( Ident | String ) '(') @@"` BitCast *BitCast `parser:"| '(' @@"` @@ -78,56 +72,122 @@ type Factor struct { type BitCast struct { Pos lexer.Position Expr *Expression `parser:"@@ ')'"` - Type string `parser:"(':' @('*'* Ident))?"` + Type *Type `parser:"(':' @@)?"` } -type Term struct { +type Assignment struct { + Pos lexer.Position + Idents []*Identifier `parser:"@@ ( ',' @@ )*"` + Op string `parser:"@(('+'|'-'|'*'|'/'|'%'|'&'|'|'|'^'|'<''<'|'>''>'|'>''>''>'|'?''?')?'=')"` + Right *Expression `parser:"@@"` +} + +type Expression struct { + Pos lexer.Position + Condition *LogicalOr `parser:"@@"` + True *Expression `parser:"('?' @@ ':')?"` + False *Expression `parser:"@@?"` +} + +type LogicalOr struct { Pos lexer.Position - Left *Factor `parser:"@@"` - Right []*OpTerm `parser:"@@*"` + Left *LogicalAnd `parser:"@@"` + Op string `parser:"@( '|' '|' | 'or' )?"` + Right []*LogicalOr `parser:"@@?"` } -type OpTerm struct { - Pos lexer.Position - Op string `parser:"@( '*' | '/' | '%' )"` - Term *Factor `parser:"@@"` +type LogicalAnd struct { + Pos lexer.Position + Left *BitwiseOr `parser:"@@"` + Op string `parser:"@( '&' '&' | 'and' )?"` + Right []*LogicalAnd `parser:"@@?"` } -type Comparison struct { +type BitwiseOr struct { Pos lexer.Position - Left *Term `parser:"@@"` - Right []*OpComparison `parser:"@@*"` + Left *BitwiseXor `parser:"@@"` + Op string `parser:"@'|'?"` + Right []*BitwiseOr `parser:"@@?"` } -type OpComparison struct { - Pos lexer.Position - Op string `parser:"@( ('=' '=') | ( '<' '=' ) | '<' | ( '>' '=' ) |'>' | ('!' '=') )"` - Comparison *Term `parser:"@@"` +type BitwiseXor struct { + Pos lexer.Position + Left *BitwiseAnd `parser:"@@"` + Op string `parser:"@'^'?"` + Right []*BitwiseXor `parser:"@@?"` } -type Expression struct { +type BitwiseAnd struct { Pos lexer.Position - Left *Comparison `parser:"@@"` - Right []*OpExpression `parser:"@@*"` + Left *Equality `parser:"@@"` + Op string `parser:"@'&'?"` + Right []*BitwiseAnd `parser:"@@?"` } -type OpExpression struct { - Pos lexer.Position - Op string `parser:"@( '+' | '-' | '&' '&' | '|' '|' )"` - Expression *Comparison `parser:"@@"` +type Equality struct { + Pos lexer.Position + Left *Relational `parser:"@@"` + Op string `parser:"@( '=' '=' | '!' '=' )?"` + Right []*Equality `parser:"@@?"` } -type Assignment struct { +type Relational struct { + Pos lexer.Position + Left *Shift `parser:"@@"` + Op string `parser:"@( '<' '=' | '>' '=' | '<' | '>' )?"` + Right []*Relational `parser:"@@?"` +} + +type Shift struct { + Pos lexer.Position + Left *Additive `parser:"@@"` + Op string `parser:"@( '<' '<' | '>' '>' | '>' '>' '>' )?"` + Right []*Shift `parser:"@@?"` +} + +type Additive struct { + Pos lexer.Position + Left *Multiplicative `parser:"@@"` + Op string `parser:"@( '+' | '-' )?"` + Right []*Additive `parser:"@@?"` +} + +type Multiplicative struct { + Pos lexer.Position + Left *LogicalNot `parser:"@@"` + Op string `parser:"@( '*' | '/' | '%' )?"` + Right []*Multiplicative `parser:"@@?"` +} + +type LogicalNot struct { + Pos lexer.Position + Op string `parser:"@'!'?"` + Right *BitwiseNot `parser:"@@"` +} + +type BitwiseNot struct { Pos lexer.Position - Left *Identifier `parser:"@@"` - Right *Expression `parser:"'=' @@"` + Op string `parser:"@'~'?"` + Right *PrefixAdditive `parser:"@@"` +} + +type PrefixAdditive struct { + Pos lexer.Position + Op string `parser:"@('+' '+' | '-' '-')?"` + Right *PostfixAdditive `parser:"@@"` +} + +type PostfixAdditive struct { + Pos lexer.Position + Left *Factor `parser:"@@"` + Op string `parser:"@('+' '+' | '-' '-')?"` } type VariableDefinition struct { Pos lexer.Position - Constant bool `parser:"'const'?"` - Name string `parser:"'var' @Ident"` - Type string `parser:"':' @('*'* Ident)"` + Constant string `parser:"@('const' | 'var')"` + Name string `parser:"@Ident"` + Type *Type `parser:"':' @@"` Assignment *Expression `parser:"( '=' @@ )?"` } @@ -135,32 +195,31 @@ type FieldDefinition struct { Pos lexer.Position Private bool `parser:"@'private'?"` Name string `parser:"@Ident"` - Type string `parser:"':' @('*'* Ident) ';'"` + Type *Type `parser:"':' @@ ';'"` } type ArgumentDefinition struct { Pos lexer.Position Name string `parser:"@Ident"` - Type string `parser:"':' @('*'* Ident)"` + Type *Type `parser:"':' @@"` } type FuncName struct { - Dummy string `parser:"'func'"` - Op bool `parser:"@'op'?"` - Get bool `parser:"@'get'?"` - Set bool `parser:"@'set'?"` - Name string `parser:"@Ident?"` - String string `parser:"@String?"` + Dummy string `parser:"'func'"` + Op bool `parser:"@'op'?"` + Get bool `parser:"@'get'?"` + Set bool `parser:"@'set'?"` + Name string `parser:"@(Ident | String)"` } type FunctionDefinition struct { Pos lexer.Position Private bool `parser:"@'private'?"` Static bool `parser:"@'static'?"` - Variadic bool `parser:"@'vararg'?"` Name FuncName `parser:"@@"` - Parameters []*ArgumentDefinition `parser:"'(' ( @@ ( ',' @@ )* )? ')'"` - ReturnType string `parser:"( ':' @('*'* Ident) )?"` + Parameters []*ArgumentDefinition `parser:"'(' ( @@ ( ',' @@ )* )?"` + Variadic string `parser:"(',' '.' '.' '.' @Ident)?"` + ReturnType []*Type `parser:"')' ( ':' @@ ( ',' @@ )* )?"` Body []*Statement `parser:"'{' @@* '}'"` } @@ -204,21 +263,62 @@ type While struct { Body []*Statement `parser:"'{' @@* '}'"` } +type Until struct { + Pos lexer.Position + Condition *Expression `parser:"'(' @@ ')'"` + Body []*Statement `parser:"'{' @@* '}'"` +} + +type Switch struct { + Pos lexer.Position + Condition *Expression `parser:"'(' @@ ')'"` + Cases []*Case `parser:"'{' @@*"` + Default []*Statement `parser:"('default' ':' @@*)? '}'"` +} + +type Case struct { + Pos lexer.Position + Values []*Expression `parser:"'case' @@ ( ',' @@ )* ':'"` + Body []*Statement `parser:"@@*"` +} + type Return struct { - Pos lexer.Position - Expression *Expression `parser:"@@? ';'"` + Pos lexer.Position + Expressions []*Expression `parser:"@@? ( ',' @@ )* ';'"` } type ExternalFunctionDefinition struct { Pos lexer.Position - Variadic bool `parser:"@'vararg'?"` Name string `parser:"'func' @( Ident | String )"` - Parameters []*ArgumentDefinition `parser:"'(' ( @@ ( ',' @@ )* )? ')'"` - ReturnType string `parser:"( ':' @('*'* Ident) )?"` + Parameters []*ArgumentDefinition `parser:"'(' ( @@ ( ',' @@ )* )?"` + Variadic bool `parser:"@(',' '.' '.' '.')?"` + ReturnType []*Type `parser:"')' ( ':' @@ ( ',' @@ )* )?"` +} + +type TryCatch struct { + Pos lexer.Position + Try []*Statement `parser:"'try' '{' @@* '}'"` + Catch *Catch `parser:"'catch' @@"` + Final []*Statement `parser:"('finally' '{' @@* '}')?"` +} + +type Catch struct { + Pos lexer.Position + Name string `parser:"@Ident"` + Body []*Statement `parser:"'{' @@* '}'"` +} + +type Type struct { + Pos lexer.Position + Array *Expression `parser:"('[' @@ ']')?"` + Ptr string `parser:"@'*'*"` + Name string `parser:"@Ident"` + Inner *Type `parser:"| @@"` } type Import struct { - Package string `parser:"@String ';'"` + Package string `parser:"@String"` + Alias string `parser:"('as' @Ident)? ';'"` } type FromImport struct { @@ -239,17 +339,20 @@ type Symbol struct { type Statement struct { Pos lexer.Position - VariableDefinition *VariableDefinition `parser:"(?= 'const'? 'var' Ident) @@? (';' | '\\n')?"` - Assignment *Assignment `parser:"| (?= Ident ( '[' ~']' ']' )? ( '.' Ident ( '[' ~']' ']' )? )* '=') @@? (';' | '\\n')?"` + VariableDefinition *VariableDefinition `parser:"(?= ('const' | 'var') Ident) @@? (';' | '\\n')?"` + Assignment *Assignment `parser:"| (?= Ident ('['~']'']')?('.'Ident('['~']'']')?)*(','Ident('['~']'']')?('.'Ident('['~']'']')?)*)*('+'|'-'|'*'|'/'|'%'|'&'|'|'|'^'|'<''<'|'>''>'|'>''>''>'|'?''?')?'=')@@?(';' | '\\n')?"` External *ExternalFunctionDefinition `parser:"| 'extern' @@ ';'"` Export *Statement `parser:"| 'export' @@"` FunctionDefinition *FunctionDefinition `parser:"| (?= 'private'? 'static'? 'func') @@?"` + TryCatch *TryCatch `parser:"| 'try' @@"` + Switch *Switch `parser:"| 'switch' @@"` ClassDefinition *ClassDefinition `parser:"| 'class' @@?"` If *If `parser:"| 'if' @@?"` For *For `parser:"| 'for' @@?"` While *While `parser:"| 'while' @@?"` + Until *Until `parser:"| 'until' @@?"` Return *Return `parser:"| 'return' @@?"` - FieldDefinition *FieldDefinition `parser:"| (?= 'private'? Ident ':' '*'* Ident) @@?"` + FieldDefinition *FieldDefinition `parser:"| (?= 'private'? Ident ':' ('[' ~']' ']')? '*'* Ident) @@?"` Import *Import `parser:"| 'import' @@?"` FromImportMultiple *FromImportMultiple `parser:"| (?= 'from' String 'import' '{') @@?"` FromImport *FromImport `parser:"| (?= 'from' String 'import') @@?"` diff --git a/src/lib/parser/parser.go b/src/lib/parser/parser.go index 3f702f0..552675e 100644 --- a/src/lib/parser/parser.go +++ b/src/lib/parser/parser.go @@ -8,8 +8,17 @@ import ( ) var parser *participle.Parser[Program] +var parsed map[string]*Program func ParseFile(filename string) *Program { + if parsed == nil { + parsed = make(map[string]*Program) + } + + if parsed[filename] != nil { + return parsed[filename] + } + if parser == nil { parser = participle.MustBuild[Program](participle.Lexer(cflex.DefaultDefinition)) } @@ -23,6 +32,7 @@ func ParseFile(filename string) *Program { if err != nil { panic(err) } + parsed[filename] = ast return ast } diff --git a/src/lib/project/cfconfig.go b/src/lib/project/cfconfig.go index 4b4b2a0..c7c2995 100644 --- a/src/lib/project/cfconfig.go +++ b/src/lib/project/cfconfig.go @@ -17,6 +17,16 @@ type CfConf struct { Dependencies []CFConfDependency `yaml:"dependencies"` Author string `yaml:"author"` License string `yaml:"license"` + Scripts map[string]string `yaml:"scripts"` + Compiler CFConfCompiler `yaml:"compiler"` +} + +type CFConfCompiler struct { + Target string `yaml:"target"` + OptimizationLevel int `yaml:"optimization"` + ClangFlags string `yaml:"clangFlags"` + GCCFlags string `yaml:"gccFlags"` + LLCFlags string `yaml:"llcFlags"` } type CFConfDependency struct { @@ -25,8 +35,11 @@ type CFConfDependency struct { Identifier string `yaml:"identifier"` } -func (c *CfConf) CreateDefault() { - c.Name = "NewProject" +func (c *CfConf) CreateDefault(name string) { + if name == "." { + name = "NewProject" + } + c.Name = name c.Description = "A new CaffeineC project" c.Version = "1.0.0" c.Main = "src/main.cffc" diff --git a/src/projects.go b/src/projects.go index 228d523..a5e160e 100644 --- a/src/projects.go +++ b/src/projects.go @@ -450,7 +450,7 @@ func initProject(c *cli.Context) error { if util.PromptYN("Use default configuration?", false) { conf := project.CfConf{} - conf.CreateDefault() + conf.CreateDefault(rootDir) err := conf.Save(path.Join(rootDir, "cfconf.yaml"), false) if err != nil {