diff --git a/backend_wat.go b/backend_wat.go index 19c2e63..d72395a 100644 --- a/backend_wat.go +++ b/backend_wat.go @@ -2,7 +2,6 @@ package main import ( "errors" - "log" "strconv" "unicode" ) @@ -181,29 +180,28 @@ func compileExpressionWAT(expr Expression, block Block) (string, error) { return "local.get $" + strconv.Itoa(block.Locals[ref.Variable].Index) + "\n" + cast, nil case Expression_Binary: - arith := expr.Value.(BinaryExpression) - - log.Printf("%+#v", arith) + binary := expr.Value.(BinaryExpression) // TODO: currently fine, only allowed for primitive types, but should be expanded to allow e.g. strings + resultType := binary.ResultType.Value.(PrimitiveType) exprType := expr.ValueType.Value.(PrimitiveType) - watLeft, err := compileExpressionWAT(arith.Left, block) + watLeft, err := compileExpressionWAT(binary.Left, block) if err != nil { return "", err } - castLeft, err := castPrimitiveWAT(arith.Left.ValueType.Value.(PrimitiveType), exprType) + castLeft, err := castPrimitiveWAT(binary.Left.ValueType.Value.(PrimitiveType), resultType) if err != nil { return "", err } - watRight, err := compileExpressionWAT(arith.Right, block) + watRight, err := compileExpressionWAT(binary.Right, block) if err != nil { return "", err } - castRight, err := castPrimitiveWAT(arith.Right.ValueType.Value.(PrimitiveType), exprType) + castRight, err := castPrimitiveWAT(binary.Right.ValueType.Value.(PrimitiveType), resultType) if err != nil { return "", err } @@ -211,23 +209,35 @@ func compileExpressionWAT(expr Expression, block Block) (string, error) { op := "" suffix := "" - if isUnsignedInt(exprType) { + if isUnsignedInt(resultType) { suffix = "u" } else { suffix = "s" } - switch arith.Operation { + switch binary.Operation { case Operation_Add: - op = getPrimitiveWATType(exprType) + ".add\n" + op = getPrimitiveWATType(resultType) + ".add\n" case Operation_Sub: - op = getPrimitiveWATType(exprType) + ".sub\n" + op = getPrimitiveWATType(resultType) + ".sub\n" case Operation_Mul: - op = getPrimitiveWATType(exprType) + ".mul\n" + op = getPrimitiveWATType(resultType) + ".mul\n" case Operation_Div: - op = getPrimitiveWATType(exprType) + ".div_" + suffix + "\n" + op = getPrimitiveWATType(resultType) + ".div_" + suffix + "\n" case Operation_Mod: - op = getPrimitiveWATType(exprType) + ".rem_" + suffix + "\n" + op = getPrimitiveWATType(resultType) + ".rem_" + suffix + "\n" + case Operation_Greater: + op = getPrimitiveWATType(resultType) + ".gt_" + suffix + "\n" + case Operation_Less: + op = getPrimitiveWATType(resultType) + ".lt_" + suffix + "\n" + case Operation_GreaterEquals: + op = getPrimitiveWATType(resultType) + ".ge_" + suffix + "\n" + case Operation_LessEquals: + op = getPrimitiveWATType(resultType) + ".le_" + suffix + "\n" + case Operation_NotEquals: + op = getPrimitiveWATType(resultType) + ".ne\n" + case Operation_Equals: + op = getPrimitiveWATType(resultType) + ".eq\n" default: panic("operation not implemented") } @@ -306,6 +316,8 @@ func compileStatementWAT(stmt Statement, block Block) (string, error) { return "", err } + // TODO: upcast to return type for non-primitive types + return wat + "return\n", nil case Statement_DeclareLocalVariable: dlv := stmt.Value.(DeclareLocalVariableStatement) @@ -319,6 +331,53 @@ func compileStatementWAT(stmt Statement, block Block) (string, error) { } return wat + "local.set $" + strconv.Itoa(block.Locals[dlv.Variable].Index) + "\n", nil + case Statement_If: + ifS := stmt.Value.(IfStatement) + + conditionWAT, err := compileExpressionWAT(ifS.Condition, block) + if err != nil { + return "", err + } + + condBlockWAT, err := compileBlockWAT(ifS.ConditionalBlock) + if err != nil { + return "", err + } + + wat := "" + + if ifS.ElseBlock != nil { + wat += "block\n" + } + + // condition + wat += "block\n" + + wat += conditionWAT + wat += "i32.eqz\n" // logical not + wat += "br_if 0\n" + + // condition is true + wat += condBlockWAT + + if ifS.ElseBlock != nil { + wat += "br 1\n" // jump over else block + } + + wat += "end\n" + + if ifS.ElseBlock != nil { + // condition is false + elseWAT, err := compileBlockWAT(*ifS.ElseBlock) + if err != nil { + return "", err + } + + wat += elseWAT + wat += "end\n" + } + + return wat, nil } panic("stmt not implemented") diff --git a/example/add.lang b/example/add.lang index a951ef5..b7d92f6 100644 --- a/example/add.lang +++ b/example/add.lang @@ -2,14 +2,6 @@ u64 add(u8 a, u8 b) { return add(a - 1u8, a); } -u64 add2(u64 a, u64 b) { - if(a == b) { - return 0; - } - - return a + b; -} - void a() { } @@ -17,3 +9,19 @@ void a() { (u8, u8) doNothing(u8 a, u8 b) { return a, b; } + +u64 doStuff(u64 a, u64 b) { + if(a > b) { + return 1u64; + } + + return 2u64; +} + +u64 fib(u64 n) { + if(n <= 1u64) { + return 1u64; + } + + return fib(n - 1u64) + fib(n - 2u64); +} diff --git a/parser.go b/parser.go index 82a4b2c..745453f 100644 --- a/parser.go +++ b/parser.go @@ -125,7 +125,7 @@ type BinaryExpression struct { Operation Operation Left Expression Right Expression - ResultType *Type + ResultType *Type // Type to expand the operands to before performing the operation } type TupleExpression struct { diff --git a/validator.go b/validator.go index e8b2a9f..d1fee57 100644 --- a/validator.go +++ b/validator.go @@ -1,18 +1,45 @@ package main import ( + "log" "strconv" ) type Validator struct { file *ParsedFile + + currentBlock *Block + currentFunction *ParsedFunction } func isTypeExpandableTo(from Type, to Type) bool { - if from.Type == Type_Primitive && to.Type == Type_Primitive { + if from.Type != to.Type { + // cannot convert between primitive, named, array and tuple types + return false + } + + if from.Type == Type_Primitive { return isPrimitiveTypeExpandableTo(from.Value.(PrimitiveType), to.Value.(PrimitiveType)) } + if from.Type == Type_Tuple { + fromT := from.Value.(TupleType) + toT := to.Value.(TupleType) + + if len(fromT.Types) != len(toT.Types) { + return false + } + + for i := 0; i < len(fromT.Types); i++ { + if !isTypeExpandableTo(fromT.Types[i], toT.Types[i]) { + return false + } + } + + return true + } + + log.Printf("%+#v %+#v", from, to) panic("not implemented") } @@ -76,7 +103,7 @@ func (v *Validator) getArithmeticResultType(expr *Expression, left PrimitiveType return InvalidValue, v.createError("cannot use these types in an arithmetic expression without an explicit cast", expr.Position) // TODO: include type names in error } -func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *Block) []error { +func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error { var errors []error switch expr.Type { @@ -84,12 +111,12 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B assignment := expr.Value.(AssignmentExpression) var local Local var ok bool - if local, ok = block.Locals[assignment.Variable]; !ok { + if local, ok = v.currentBlock.Locals[assignment.Variable]; !ok { errors = append(errors, v.createError("Assignment to undeclared variable "+assignment.Variable, expr.Position)) return errors } - valueErrors := v.validateExpression(&assignment.Value, block) + valueErrors := v.validateExpression(&assignment.Value) if len(valueErrors) != 0 { errors = append(errors, valueErrors...) return errors @@ -111,7 +138,7 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B reference := expr.Value.(VariableReferenceExpression) var local Local var ok bool - if local, ok = block.Locals[reference.Variable]; !ok { + if local, ok = v.currentBlock.Locals[reference.Variable]; !ok { errors = append(errors, v.createError("Reference to undeclared variable "+reference.Variable, expr.Position)) return errors } @@ -120,8 +147,8 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B case Expression_Binary: binary := expr.Value.(BinaryExpression) - errors = append(errors, v.validateExpression(&binary.Left, block)...) - errors = append(errors, v.validateExpression(&binary.Right, block)...) + errors = append(errors, v.validateExpression(&binary.Left)...) + errors = append(errors, v.validateExpression(&binary.Right)...) if len(errors) != 0 { return errors @@ -180,7 +207,7 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B for i := range tuple.Members { member := &tuple.Members[i] - memberErrors := v.validateExpression(member, block) + memberErrors := v.validateExpression(member) if len(memberErrors) != 0 { errors = append(errors, memberErrors...) continue @@ -212,7 +239,7 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B } if fc.Parameters != nil { - paramsErrors := v.validateExpression(fc.Parameters, block) + paramsErrors := v.validateExpression(fc.Parameters) if len(paramsErrors) != 0 { errors = append(errors, paramsErrors...) return errors @@ -242,7 +269,7 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B case Expression_Negate: neg := expr.Value.(NegateExpression) - valErrors := v.validateExpression(&neg.Value, block) + valErrors := v.validateExpression(&neg.Value) if len(valErrors) != 0 { errors = append(errors, valErrors...) return errors @@ -261,8 +288,8 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B return errors } -func (v *Validator) validateExpression(expr *Expression, block *Block) []error { - errors := v.validatePotentiallyVoidExpression(expr, block) +func (v *Validator) validateExpression(expr *Expression) []error { + errors := v.validatePotentiallyVoidExpression(expr) if len(errors) != 0 { return errors } @@ -274,7 +301,7 @@ func (v *Validator) validateExpression(expr *Expression, block *Block) []error { return errors } -func (v *Validator) validateStatement(stmt *Statement, block *Block, functionLocals *[]Local) []error { +func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local) []error { var errors []error // TODO: support references to variables in parent block @@ -282,7 +309,7 @@ func (v *Validator) validateStatement(stmt *Statement, block *Block, functionLoc switch stmt.Type { case Statement_Expression: expression := stmt.Value.(ExpressionStatement) - errors = append(errors, v.validatePotentiallyVoidExpression(&expression.Expression, block)...) + errors = append(errors, v.validatePotentiallyVoidExpression(&expression.Expression)...) stmt.Value = expression case Statement_Block: block := stmt.Value.(BlockStatement) @@ -291,20 +318,35 @@ func (v *Validator) validateStatement(stmt *Statement, block *Block, functionLoc case Statement_Return: ret := stmt.Value.(ReturnStatement) if ret.Value != nil { - errors = append(errors, v.validateExpression(ret.Value, block)...) + if v.currentFunction.ReturnType == nil { + errors = append(errors, v.createError("cannot return value from void function", stmt.Position)) + return errors + } + + errors = append(errors, v.validateExpression(ret.Value)...) + + if len(errors) != 0 { + return errors + } + + if !isTypeExpandableTo(*ret.Value.ValueType, *v.currentFunction.ReturnType) { + errors = append(errors, v.createError("expression type does not match function return type", ret.Value.Position)) + } + } else if v.currentFunction.ReturnType != nil { + errors = append(errors, v.createError("missing return value", stmt.Position)) } case Statement_DeclareLocalVariable: dlv := stmt.Value.(DeclareLocalVariableStatement) if dlv.Initializer != nil { - errors = append(errors, v.validateExpression(dlv.Initializer, block)...) + errors = append(errors, v.validateExpression(dlv.Initializer)...) } - if _, ok := block.Locals[dlv.Variable]; ok { + if _, ok := v.currentBlock.Locals[dlv.Variable]; ok { errors = append(errors, v.createError("redeclaration of variable '"+dlv.Variable+"'", stmt.Position)) } local := Local{Name: dlv.Variable, Type: dlv.VariableType, IsParameter: false, Index: len(*functionLocals)} - block.Locals[dlv.Variable] = local + v.currentBlock.Locals[dlv.Variable] = local *functionLocals = append(*functionLocals, local) // TODO: check if assignment of initializer is correct @@ -312,7 +354,7 @@ func (v *Validator) validateStatement(stmt *Statement, block *Block, functionLoc case Statement_If: ifS := stmt.Value.(IfStatement) - errors = append(errors, v.validateExpression(&ifS.Condition, block)...) + errors = append(errors, v.validateExpression(&ifS.Condition)...) errors = append(errors, v.validateBlock(&ifS.ConditionalBlock, functionLocals)...) if ifS.ElseBlock != nil { @@ -343,8 +385,9 @@ func (v *Validator) validateBlock(block *Block, functionLocals *[]Local) []error } for i := range block.Statements { + v.currentBlock = block stmt := &block.Statements[i] - errors = append(errors, v.validateStatement(stmt, block, functionLocals)...) + errors = append(errors, v.validateStatement(stmt, functionLocals)...) } return errors @@ -355,6 +398,8 @@ func (v *Validator) validateFunction(function *ParsedFunction) []error { var locals []Local + v.currentFunction = function + body := &function.Body body.Locals = make(map[string]Local) for _, param := range function.Parameters {