From 63ccacba2d4d3806fc34e5f881729d3abc524ee3 Mon Sep 17 00:00:00 2001 From: MrLetsplay Date: Wed, 20 Mar 2024 19:26:48 +0100 Subject: [PATCH] Improve compiler & validator, Update lexer --- backend_wat.go | 60 ++++++++++++++++- example/add.lang | 8 ++- lexer.go | 7 +- main.go | 3 +- parser.go | 164 +++++++++++++++++++++++++++++++++-------------- validator.go | 138 ++++++++++++++++++++++++++++----------- 6 files changed, 282 insertions(+), 98 deletions(-) diff --git a/backend_wat.go b/backend_wat.go index 165313c..d275b16 100644 --- a/backend_wat.go +++ b/backend_wat.go @@ -2,6 +2,7 @@ package main import ( "errors" + "log" "strconv" ) @@ -136,9 +137,11 @@ func castPrimitiveWAT(from PrimitiveType, to PrimitiveType) (string, error) { } func compileExpressionWAT(expr Expression, block Block) (string, error) { + var err error + switch expr.Type { case Expression_Assignment: - + // TODO case Expression_Literal: lit := expr.Value.(LiteralExpression) switch lit.Literal.Type { @@ -165,6 +168,8 @@ func compileExpressionWAT(expr Expression, block Block) (string, error) { case Expression_Arithmetic: arith := expr.Value.(ArithmeticExpression) + log.Printf("%+#v", arith) + // TODO: currently fine, only allowed for primitive types, but should be expanded to allow e.g. strings exprType := expr.ValueType.Value.(PrimitiveType) @@ -212,9 +217,51 @@ func compileExpressionWAT(expr Expression, block Block) (string, error) { return watLeft + castLeft + watRight + castRight + op + getTypeCast(exprType), nil case Expression_Tuple: + tuple := expr.Value.(TupleExpression) + + wat := "" + for _, member := range tuple.Members { + memberWAT, err := compileExpressionWAT(member, block) + if err != nil { + return "", err + } + + wat += memberWAT + } + + return wat, nil + case Expression_FunctionCall: + fc := expr.Value.(FunctionCallExpression) + + wat := "" + if fc.Parameters != nil { + wat, err = compileExpressionWAT(*fc.Parameters, block) + if err != nil { + return "", err + } + } + + return wat + "call $" + fc.Function + "\n", nil + case Expression_Negate: + neg := expr.Value.(NegateExpression) + exprType := expr.ValueType.Value.(PrimitiveType) + + wat, err := compileExpressionWAT(neg.Value, block) + if err != nil { + return "", err + } + + watType := getPrimitiveWATType(exprType) + if isSignedInt(exprType) || isUnsignedInt(exprType) { + return watType + ".const 0\n" + wat + watType + ".sub\n", nil + } + + if isFloatingPoint(exprType) { + return watType + ".neg\n", nil + } } - return "", nil + panic("expr not implemented") } func compileStatementWAT(stmt Statement, block Block) (string, error) { @@ -286,7 +333,14 @@ func compileFunctionWAT(function ParsedFunction) (string, error) { } // TODO: tuples - funcWAT += "\t(result " + getWATType(function.ReturnType) + ")\n" + returnTypes := []Type{function.ReturnType} + if function.ReturnType.Type == Type_Tuple { + returnTypes = function.ReturnType.Value.(TupleType).Types + } + + for _, t := range returnTypes { + funcWAT += "\t(result " + getWATType(t) + ")\n" + } for _, local := range function.Locals { if local.IsParameter { diff --git a/example/add.lang b/example/add.lang index 439547f..34e1c31 100644 --- a/example/add.lang +++ b/example/add.lang @@ -1,3 +1,7 @@ -u64 add(u8 a, u64 b) { - return a * a + b * b; +u64 add(u8 a, u8 b) { + return add(a - 1u8, a); +} + +(u8, u8) doNothing(u8 a, u8 b) { + return a, b; } diff --git a/lexer.go b/lexer.go index 9f91f74..65cc5c2 100644 --- a/lexer.go +++ b/lexer.go @@ -212,8 +212,7 @@ func (l *Lexer) parseToken(token string) (*LexToken, error) { } runes := []rune(token) - startsWithMinus := runes[0] == '-' - if startsWithMinus || unicode.IsDigit([]rune(token)[0]) { + if unicode.IsDigit([]rune(token)[0]) { // TODO: hexadecimal/binary/octal constants var numberType PrimitiveType = InvalidValue @@ -230,10 +229,8 @@ func (l *Lexer) parseToken(token string) (*LexToken, error) { if numberType == InvalidValue { if containsDot { numberType = Primitive_F64 - } else if startsWithMinus { - numberType = Primitive_I64 } else { - numberType = Primitive_U64 + numberType = Primitive_I64 } } diff --git a/main.go b/main.go index 7598427..df57c55 100644 --- a/main.go +++ b/main.go @@ -39,7 +39,8 @@ func main() { log.Printf("Parsed:\n%+#v\n\n", parsed) - errors := validator(parsed) + validator := Validator{file: parsed} + errors := validator.validate() if len(errors) != 0 { for _, err = range errors { if c, ok := err.(CompilerError); ok { diff --git a/parser.go b/parser.go index e5b8b79..62947d3 100644 --- a/parser.go +++ b/parser.go @@ -72,6 +72,8 @@ const ( Expression_VariableReference Expression_Arithmetic Expression_Tuple + Expression_FunctionCall + Expression_Negate ) type Expression struct { @@ -113,6 +115,15 @@ type TupleExpression struct { Members []Expression } +type FunctionCallExpression struct { + Function string + Parameters *Expression +} + +type NegateExpression struct { + Value Expression +} + type Local struct { Name string Type Type @@ -413,8 +424,51 @@ func (p *Parser) tryUnaryExpression() (*Expression, error) { } } + if token.Type == Type_Operator { + op := token.Value.(Operator) + if op == Operator_Minus || op == Operator_Plus { + pCopy.nextToken() + expr, err := pCopy.tryUnaryExpression() + if err != nil { + return nil, err + } + + if expr == nil { + return nil, nil + } + + if op == Operator_Minus { + expr = &Expression{Type: Expression_Negate, Value: NegateExpression{Value: *expr}} + } + + return expr, nil + } + } + if token.Type == Type_Identifier { pCopy.nextToken() + + next, err := pCopy.trySeparator(Separator_OpenParen) + if err != nil { + return nil, err + } + + if next != nil { + // Function call + params, err := pCopy.tryTupleExpression() + if err != nil { + return nil, err + } + + _, err = pCopy.expectSeparator(Separator_CloseParen) + if err != nil { + return nil, err + } + + *p = pCopy + return &Expression{Type: Expression_FunctionCall, Value: FunctionCallExpression{Function: token.Value.(string), Parameters: params}}, nil + } + *p = pCopy return &Expression{Type: Expression_VariableReference, Value: VariableReferenceExpression{Variable: token.Value.(string)}}, nil } @@ -428,41 +482,47 @@ func (p *Parser) tryMultiplicativeExpression() (*Expression, error) { return nil, err } - op, err := p.tryOperator(Operator_Multiply, Operator_Divide, Operator_Modulo) - if err != nil { - return nil, err + if left == nil { + return nil, nil } - if op == nil { - return left, nil - } + for { + op, err := p.tryOperator(Operator_Multiply, Operator_Divide, Operator_Modulo) + if err != nil { + return nil, err + } - right, err := p.tryUnaryExpression() - if err != nil { - return nil, err - } + if op == nil { + return left, nil + } - if right == nil { - return nil, p.error("expected expression") - } + right, err := p.tryUnaryExpression() + if err != nil { + return nil, err + } - var operation ArithmeticOperation - switch *op { - case Operator_Multiply: - operation = Arithmetic_Mul - case Operator_Divide: - operation = Arithmetic_Div - case Operator_Plus: - operation = Arithmetic_Add - case Operator_Minus: - operation = Arithmetic_Sub - case Operator_Modulo: - fallthrough - default: - operation = Arithmetic_Mod - } + if right == nil { + return nil, p.error("expected expression") + } - return &Expression{Type: Expression_Arithmetic, Value: ArithmeticExpression{Operation: operation, Left: *left, Right: *right}}, nil + var operation ArithmeticOperation + switch *op { + case Operator_Multiply: + operation = Arithmetic_Mul + case Operator_Divide: + operation = Arithmetic_Div + case Operator_Plus: + operation = Arithmetic_Add + case Operator_Minus: + operation = Arithmetic_Sub + case Operator_Modulo: + fallthrough + default: + operation = Arithmetic_Mod + } + + left = &Expression{Type: Expression_Arithmetic, Value: ArithmeticExpression{Operation: operation, Left: *left, Right: *right}} + } } func (p *Parser) tryAdditiveExpression() (*Expression, error) { @@ -471,32 +531,38 @@ func (p *Parser) tryAdditiveExpression() (*Expression, error) { return nil, err } - op, err := p.tryOperator(Operator_Plus, Operator_Minus) - if err != nil { - return nil, err + if left == nil { + return nil, nil } - if op == nil { - return left, nil - } + for { + op, err := p.tryOperator(Operator_Plus, Operator_Minus) + if err != nil { + return nil, err + } - right, err := p.tryMultiplicativeExpression() - if err != nil { - return nil, err - } + if op == nil { + return left, nil + } - if right == nil { - return nil, p.error("expected expression") - } + right, err := p.tryMultiplicativeExpression() + if err != nil { + return nil, err + } - var operation ArithmeticOperation - if *op == Operator_Plus { - operation = Arithmetic_Add - } else { - operation = Arithmetic_Sub - } + if right == nil { + return nil, p.error("expected expression") + } - return &Expression{Type: Expression_Arithmetic, Value: ArithmeticExpression{Operation: operation, Left: *left, Right: *right}}, nil + var operation ArithmeticOperation + if *op == Operator_Plus { + operation = Arithmetic_Add + } else { + operation = Arithmetic_Sub + } + + left = &Expression{Type: Expression_Arithmetic, Value: ArithmeticExpression{Operation: operation, Left: *left, Right: *right}} + } } func (p *Parser) tryArithmeticExpression() (*Expression, error) { diff --git a/validator.go b/validator.go index 328158d..c1cc611 100644 --- a/validator.go +++ b/validator.go @@ -2,19 +2,22 @@ package main import ( "errors" + "strconv" ) -func createError(message string) error { - // TODO: pass token and get actual token position - return errors.New(message) +type Validator struct { + file *ParsedFile } -func validateImport(imp *Import) []error { - // TODO - return nil +func isTypeExpandableTo(from Type, to Type) bool { + if from.Type == Type_Primitive && to.Type == Type_Primitive { + return isPrimitiveTypeExpandableTo(from.Value.(PrimitiveType), to.Value.(PrimitiveType)) + } + + panic("not implemented") } -func isTypeExpandableTo(from PrimitiveType, to PrimitiveType) bool { +func isPrimitiveTypeExpandableTo(from PrimitiveType, to PrimitiveType) bool { if from == to { return true } @@ -46,25 +49,35 @@ func isTypeExpandableTo(from PrimitiveType, to PrimitiveType) bool { return false } -func getArithmeticResultType(left PrimitiveType, right PrimitiveType, operation ArithmeticOperation) (PrimitiveType, error) { +func (v *Validator) createError(message string) error { + // TODO: pass token and get actual token position + return errors.New(message) +} + +func (v *Validator) validateImport(imp *Import) []error { + // TODO + return nil +} + +func (v *Validator) getArithmeticResultType(left PrimitiveType, right PrimitiveType, operation ArithmeticOperation) (PrimitiveType, error) { if left == Primitive_Bool || right == Primitive_Bool { - return InvalidValue, createError("bool type cannot be used in arithmetic expressions") + return InvalidValue, v.createError("bool type cannot be used in arithmetic expressions") } - if isTypeExpandableTo(left, right) { + if isPrimitiveTypeExpandableTo(left, right) { return right, nil } - if isTypeExpandableTo(right, left) { + if isPrimitiveTypeExpandableTo(right, left) { return left, nil } // TODO: boolean expressions etc. - return InvalidValue, createError("cannot use these types in an arithmetic expression without an explicit cast") // TODO: include type names in error + return InvalidValue, v.createError("cannot use these types in an arithmetic expression without an explicit cast") // TODO: include type names in error } -func validateExpression(expr *Expression, block *Block) []error { +func (v *Validator) validateExpression(expr *Expression, block *Block) []error { var errors []error switch expr.Type { @@ -73,11 +86,11 @@ func validateExpression(expr *Expression, block *Block) []error { var local Local var ok bool if local, ok = block.Locals[assignment.Variable]; !ok { - errors = append(errors, createError("Assignment to undeclared variable "+assignment.Variable)) + errors = append(errors, v.createError("Assignment to undeclared variable "+assignment.Variable)) return errors } - valueErrors := validateExpression(&assignment.Value, block) + valueErrors := v.validateExpression(&assignment.Value, block) if len(valueErrors) != 0 { errors = append(errors, valueErrors...) return errors @@ -100,31 +113,29 @@ func validateExpression(expr *Expression, block *Block) []error { var local Local var ok bool if local, ok = block.Locals[reference.Variable]; !ok { - errors = append(errors, createError("Reference to undeclared variable "+reference.Variable)) + errors = append(errors, v.createError("Reference to undeclared variable "+reference.Variable)) return errors } expr.ValueType = local.Type - expr.Value = reference case Expression_Arithmetic: arithmethic := expr.Value.(ArithmeticExpression) - errors = append(errors, validateExpression(&arithmethic.Left, block)...) - errors = append(errors, validateExpression(&arithmethic.Right, block)...) + errors = append(errors, v.validateExpression(&arithmethic.Left, block)...) + errors = append(errors, v.validateExpression(&arithmethic.Right, block)...) if len(errors) != 0 { return errors } - // TODO: validate types compatible and determine result type if arithmethic.Left.ValueType.Type != Type_Primitive || arithmethic.Right.ValueType.Type != Type_Primitive { - errors = append(errors, createError("both sides of an arithmetic expression must be a primitive type")) + errors = append(errors, v.createError("both sides of an arithmetic expression must be a primitive type")) return errors } leftType := arithmethic.Left.ValueType.Value.(PrimitiveType) rightType := arithmethic.Right.ValueType.Value.(PrimitiveType) - result, err := getArithmeticResultType(leftType, rightType, arithmethic.Operation) + result, err := v.getArithmeticResultType(leftType, rightType, arithmethic.Operation) if err != nil { errors = append(errors, err) return errors @@ -139,7 +150,7 @@ func validateExpression(expr *Expression, block *Block) []error { for i := range tuple.Members { member := &tuple.Members[i] - memberErrors := validateExpression(member, block) + memberErrors := v.validateExpression(member, block) if len(memberErrors) != 0 { errors = append(errors, memberErrors...) continue @@ -154,12 +165,63 @@ func validateExpression(expr *Expression, block *Block) []error { expr.ValueType = Type{Type: Type_Tuple, Value: TupleType{Types: types}} expr.Value = tuple + case Expression_FunctionCall: + fc := expr.Value.(FunctionCallExpression) + + var calledFunc *ParsedFunction = nil + for _, f := range v.file.Functions { + if f.Name == fc.Function { + calledFunc = &f + break + } + } + + if calledFunc == nil { + errors = append(errors, v.createError("call to undefined function '"+fc.Function+"'")) + return errors + } + + if fc.Parameters != nil { + errors = append(errors, v.validateExpression(fc.Parameters, block)...) + + params := []Expression{*fc.Parameters} + if fc.Parameters.Type == Expression_Tuple { + params = fc.Parameters.Value.(TupleExpression).Members + } + + if len(params) != len(calledFunc.Parameters) { + errors = append(errors, v.createError("wrong number of arguments in function call: expected "+strconv.Itoa(len(calledFunc.Parameters))+", got "+strconv.Itoa(len(params)))) + } + + for i := 0; i < min(len(params), len(calledFunc.Parameters)); i++ { + typeGiven := params[i] + typeExpected := calledFunc.Parameters[i] + if !isTypeExpandableTo(typeGiven.ValueType, typeExpected.Type) { + errors = append(errors, v.createError("invalid type for parameter "+strconv.Itoa(i))) + } + } + } + + // TODO: get function and validate using return type + expr.ValueType = calledFunc.ReturnType + expr.Value = fc + case Expression_Negate: + neg := expr.Value.(NegateExpression) + + errors = append(errors, v.validateExpression(&neg.Value, block)...) + + if neg.Value.ValueType.Type != Type_Primitive { + errors = append(errors, v.createError("cannot negate non-number types")) + } + + expr.ValueType = neg.Value.ValueType + expr.Value = neg } return errors } -func validateStatement(stmt *Statement, block *Block, functionLocals *[]Local) []error { +func (v *Validator) validateStatement(stmt *Statement, block *Block, functionLocals *[]Local) []error { var errors []error // TODO: support references to variables in parent block @@ -167,25 +229,25 @@ func validateStatement(stmt *Statement, block *Block, functionLocals *[]Local) [ switch stmt.Type { case Statement_Expression: expression := stmt.Value.(ExpressionStatement) - errors = append(errors, validateExpression(&expression.Expression, block)...) + errors = append(errors, v.validateExpression(&expression.Expression, block)...) *stmt = Statement{Type: Statement_Expression, Value: expression} case Statement_Block: block := stmt.Value.(BlockStatement) - errors = append(errors, validateBlock(&block.Block, functionLocals)...) + errors = append(errors, v.validateBlock(&block.Block, functionLocals)...) *stmt = Statement{Type: Statement_Block, Value: block} case Statement_Return: ret := stmt.Value.(ReturnStatement) if ret.Value != nil { - errors = append(errors, validateExpression(ret.Value, block)...) + errors = append(errors, v.validateExpression(ret.Value, block)...) } case Statement_DeclareLocalVariable: dlv := stmt.Value.(DeclareLocalVariableStatement) if dlv.Initializer != nil { - errors = append(errors, validateExpression(dlv.Initializer, block)...) + errors = append(errors, v.validateExpression(dlv.Initializer, block)...) } if _, ok := block.Locals[dlv.Variable]; ok { - errors = append(errors, createError("redeclaration of variable '"+dlv.Variable+"'")) + errors = append(errors, v.createError("redeclaration of variable '"+dlv.Variable+"'")) } local := Local{Name: dlv.Variable, Type: dlv.VariableType, IsParameter: false, Index: len(*functionLocals)} @@ -198,7 +260,7 @@ func validateStatement(stmt *Statement, block *Block, functionLocals *[]Local) [ return errors } -func validateBlock(block *Block, functionLocals *[]Local) []error { +func (v *Validator) validateBlock(block *Block, functionLocals *[]Local) []error { var errors []error if block.Locals == nil { @@ -207,13 +269,13 @@ func validateBlock(block *Block, functionLocals *[]Local) []error { for i := range block.Statements { stmt := &block.Statements[i] - errors = append(errors, validateStatement(stmt, block, functionLocals)...) + errors = append(errors, v.validateStatement(stmt, block, functionLocals)...) } return errors } -func validateFunction(function *ParsedFunction) []error { +func (v *Validator) validateFunction(function *ParsedFunction) []error { var errors []error var locals []Local @@ -226,22 +288,22 @@ func validateFunction(function *ParsedFunction) []error { body.Locals[param.Name] = local } - errors = append(errors, validateBlock(body, &locals)...) + errors = append(errors, v.validateBlock(body, &locals)...) function.Locals = locals return errors } -func validator(file *ParsedFile) []error { +func (v *Validator) validate() []error { var errors []error - for i := range file.Imports { - errors = append(errors, validateImport(&file.Imports[i])...) + for i := range v.file.Imports { + errors = append(errors, v.validateImport(&v.file.Imports[i])...) } - for i := range file.Functions { - errors = append(errors, validateFunction(&file.Functions[i])...) + for i := range v.file.Functions { + errors = append(errors, v.validateFunction(&v.file.Functions[i])...) } return errors