From a7ec08b379bbd5d9a7ce96279f648c04aa8b7ace Mon Sep 17 00:00:00 2001
From: MrLetsplay
Date: Mon, 18 Mar 2024 21:14:28 +0100
Subject: [PATCH] start WAT compiler backend

---
 .gitignore     |   1 +
 backend_wat.go | 284 +++++++++++++++++++++++++++++++++++++++++++++++++
 main.go        |  14 +++
 parser.go      |   9 +-
 validator.go   |  27 ++++--
 5 files changed, 324 insertions(+), 11 deletions(-)
 create mode 100644 backend_wat.go

diff --git a/.gitignore b/.gitignore
index 86a7c8e..0c47708 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
 compiler
+out.wat
diff --git a/backend_wat.go b/backend_wat.go
new file mode 100644
index 0000000..6770516
--- /dev/null
+++ b/backend_wat.go
@@ -0,0 +1,284 @@
+package main
+
+import (
+	"errors"
+	"strconv"
+)
+
+// getWATType maps a primitive source type to the WebAssembly value type used
+// to represent it. Integer types of up to 32 bits and booleans are stored as
+// i32. It panics for non-primitive or unhandled types.
+func getWATType(t Type) string {
+	// TODO: tuples?
+
+	if t.Type != Type_Primitive {
+		panic("not implemented") // TODO: non-primitive types
+	}
+
+	primitive := t.Value.(PrimitiveType)
+	switch primitive {
+	case Primitive_I8, Primitive_I16, Primitive_I32, Primitive_U8, Primitive_U16, Primitive_U32:
+		return "i32"
+	case Primitive_I64, Primitive_U64:
+		return "i64"
+	case Primitive_F32:
+		return "f32"
+	case Primitive_F64:
+		return "f64"
+	case Primitive_Bool:
+		return "i32"
+	}
+
+	panic("unhandled type")
+}
+
+// getTypeCast returns the instructions that normalize an i32 on top of the
+// stack to the value range of the given sub-word primitive (sign extension
+// for signed types, masking for unsigned types and booleans). It returns an
+// empty string for primitives that need no normalization.
+func getTypeCast(primitive PrimitiveType) string {
+	switch primitive {
+	case Primitive_I8:
+		return "i32.extend8_s\n"
+	case Primitive_U8:
+		return "i32.const 255\ni32.and\n"
+	case Primitive_I16:
+		return "i32.extend16_s\n"
+	case Primitive_U16:
+		return "i32.const 65535\ni32.and\n"
+	case Primitive_Bool:
+		return "i32.const 1\ni32.and\n"
+	}
+
+	return ""
+}
+
+// pushConstantNumberWAT returns the instruction that pushes the numeric
+// literal value onto the stack. Signed integer literals are expected as
+// int64, unsigned ones as uint64 and floating point ones as float64; any
+// other dynamic type makes the type assertion panic.
+func pushConstantNumberWAT(primitive PrimitiveType, value any) string {
+	switch primitive {
+	case Primitive_I8, Primitive_I16, Primitive_I32:
+		return "i32.const " + strconv.FormatInt(value.(int64), 10) + "\n"
+	case Primitive_U8, Primitive_U16, Primitive_U32:
+		return "i32.const " + strconv.FormatUint(value.(uint64), 10) + "\n"
+	case Primitive_I64:
+		return "i64.const " + strconv.FormatInt(value.(int64), 10) + "\n"
+	case Primitive_U64:
+		// WebAssembly has no unsigned value types; u64 constants use i64.const.
+		return "i64.const " + strconv.FormatUint(value.(uint64), 10) + "\n"
+	case Primitive_F32:
+		return "f32.const " + strconv.FormatFloat(value.(float64), 'f', -1, 32) + "\n"
+	case Primitive_F64:
+		return "f64.const " + strconv.FormatFloat(value.(float64), 'f', -1, 64) + "\n"
+	}
+
+	panic("invalid type")
+}
+
+// upcastTypeWAT returns the instructions that convert the value on top of
+// the stack from one primitive type to another. Conversions involving bool,
+// or between integer and floating point types, are rejected with an error.
+func upcastTypeWAT(from PrimitiveType, to PrimitiveType) (string, error) {
+	// TODO: refactor
+
+	if from == to {
+		return "", nil
+	}
+
+	if from == Primitive_Bool || to == Primitive_Bool {
+		return "", errors.New("cannot upcast from or to bool")
+	}
+
+	if from == Primitive_F32 && to == Primitive_F64 {
+		return "f64.promote_f32\n", nil
+	}
+
+	if from == Primitive_F64 && to == Primitive_F32 {
+		return "f32.demote_f64\n", nil
+	}
+
+	if isFloatingPoint(from) || isFloatingPoint(to) {
+		return "", errors.New("cannot upcast int from/to float")
+	}
+
+	wat := ""
+
+	if getBits(from) == 64 && getBits(to) < 64 {
+		wat += "i32.wrap_i64\n"
+	}
+
+	if getBits(from) < 64 && getBits(to) == 64 {
+		// The extension must follow the signedness of the source value:
+		// unsigned sources are zero-extended, signed ones sign-extended.
+		switch from {
+		case Primitive_I8, Primitive_I16, Primitive_I32:
+			wat += "i64.extend_i32_s\n"
+		default:
+			wat += "i64.extend_i32_u\n"
+		}
+	}
+
+	switch to {
+	case Primitive_I8, Primitive_I16, Primitive_I32, Primitive_U8, Primitive_U16, Primitive_U32:
+		wat += getTypeCast(to)
+	}
+
+	return wat, nil
+}
+
+// compileExpressionWAT returns the instructions that evaluate expr. Locals
+// are resolved through block.Locals. Assignment and tuple expressions are
+// not implemented yet and produce no instructions.
+func compileExpressionWAT(expr Expression, block Block) (string, error) {
+	switch expr.Type {
+	case Expression_Assignment:
+
+	case Expression_Literal:
+		lit := expr.Value.(LiteralExpression)
+		switch lit.Literal.Type {
+		case Literal_Number:
+			return pushConstantNumberWAT(lit.Literal.Primitive, lit.Literal.Value), nil
+		case Literal_Boolean:
+			if lit.Literal.Value.(bool) {
+				return "i32.const 1\n", nil
+			} else {
+				return "i32.const 0\n", nil
+			}
+		case Literal_String:
+			panic("not implemented")
+		}
+	case Expression_VariableReference:
+		ref := expr.Value.(VariableReferenceExpression)
+		return "local.get $" + strconv.Itoa(block.Locals[ref.Variable].Index) + "\n", nil
+	case Expression_Arithmetic:
+		arith := expr.Value.(ArithmeticExpression)
+
+		watLeft, err := compileExpressionWAT(arith.Left, block)
+		if err != nil {
+			return "", err
+		}
+
+		watRight, err := compileExpressionWAT(arith.Right, block)
+		if err != nil {
+			return "", err
+		}
+
+		//TODO: upcast expressions and perform operation
+
+		// TODO: cast result
+		return watLeft + watRight, nil
+	case Expression_Tuple:
+	}
+
+	return "", nil
+}
+
+// compileStatementWAT returns the instructions for a single statement of
+// block. Statement types it does not know produce no instructions.
+func compileStatementWAT(stmt Statement, block Block) (string, error) {
+	switch stmt.Type {
+	case Statement_Expression:
+		expr := stmt.Value.(ExpressionStatement)
+		wat, err := compileExpressionWAT(expr.Expression, block)
+		if err != nil {
+			return "", err
+		}
+
+		return wat + "drop\n", nil
+	case Statement_Block:
+		block := stmt.Value.(BlockStatement)
+		wat, err := compileBlockWAT(block.Block)
+		if err != nil {
+			return "", err
+		}
+
+		return wat, nil
+	case Statement_Return:
+		ret := stmt.Value.(ReturnStatement)
+		if ret.Value == nil {
+			// A bare return carries no value to push onto the stack.
+			return "return\n", nil
+		}
+
+		wat, err := compileExpressionWAT(*ret.Value, block)
+		if err != nil {
+			return "", err
+		}
+
+		return wat + "return\n", nil
+	case Statement_DeclareLocalVariable:
+		dlv := stmt.Value.(DeclareLocalVariableStatement)
+		if dlv.Initializer == nil {
+			return "", nil
+		}
+
+		wat, err := compileExpressionWAT(*dlv.Initializer, block)
+		if err != nil {
+			return "", err
+		}
+
+		return wat + "local.set $" + strconv.Itoa(block.Locals[dlv.Variable].Index) + "\n", nil
+	}
+
+	return "", nil
+}
+
+// compileBlockWAT concatenates the instructions of all statements in block.
+func compileBlockWAT(block Block) (string, error) {
+	blockWAT := ""
+
+	for _, stmt := range block.Statements {
+		wat, err := compileStatementWAT(stmt, block)
+		if err != nil {
+			return "", err
+		}
+
+		blockWAT += wat
+	}
+
+	return blockWAT, nil
+}
+
+// compileFunctionWAT emits a WAT function definition for function, followed
+// by an export of it under the same name.
+func compileFunctionWAT(function ParsedFunction) (string, error) {
+	funcWAT := "(func $" + function.Name + "\n"
+
+	for _, local := range function.Locals {
+		pfx := ""
+		if local.IsParameter {
+			pfx = "param"
+		} else {
+			pfx = "local"
+		}
+		funcWAT += "\t(" + pfx + " $" + strconv.Itoa(local.Index) + " " + getWATType(local.Type) + ")\n"
+	}
+
+	wat, err := compileBlockWAT(function.Body)
+	if err != nil {
+		return "", err
+	}
+
+	funcWAT += wat
+
+	return funcWAT + ") (export \"" + function.Name + "\" (func $" + function.Name + "))", nil
+}
+
+// backendWAT compiles all functions of file into a single WAT module.
+func backendWAT(file ParsedFile) (string, error) {
+	module := "(module (memory 1)\n"
+
+	for _, function := range file.Functions {
+		wat, err := compileFunctionWAT(function)
+		if err != nil {
+			return "", err
+		}
+
+		module += wat
+	}
+
+	module += ")"
+	return module, nil
+}
diff --git a/main.go b/main.go
index 634eb28..7598427 100644
--- a/main.go
+++ b/main.go
@@ -51,4 +51,18 @@ func main() {
 	}
 
 	log.Printf("Validated:\n%+#v\n\n", parsed)
+
+	wat, err := backendWAT(*parsed)
+	if err != nil {
+		if c, ok := err.(CompilerError); ok {
+			log.Fatalln(err, "\nv- here\n"+string([]rune(string(content))[c.Position:]))
+		}
+
+		log.Fatalln(err)
+	}
+
+	log.Println("WAT: " + wat)
+	if err := os.WriteFile("out.wat", []byte(wat), 0o644); err != nil {
+		log.Fatalln(err)
+	}
 }
diff --git a/parser.go b/parser.go
index 95b146d..f449b8e 100644
--- a/parser.go
+++ b/parser.go
@@ -114,11 +114,14 @@ type TupleExpression struct {
 }
 
 type Local struct {
-	Name string
-	Type Type
+	Name        string
+	Type        Type
+	IsParameter bool
+	Index       int // unique, 0-based index of the local in the current function
 }
 
 type Block struct {
+	Parent     *Block // TODO: implement
 	Statements []Statement
 	Locals     map[string]Local
 }
@@ -133,6 +136,7 @@ type ParsedFunction struct {
 	Parameters []ParsedParameter
 	ReturnType Type
 	Body       Block
+	Locals     []Local // All of the locals of the function, ordered by their index
 }
 
 type Import struct {
@@ -576,6 +580,7 @@ func (p *Parser) tryDeclareLocalVariableStatement() (*Statement, error) {
 
 	token := pCopy.nextToken()
 	if token.Type == Type_Separator && token.Value.(Separator) == Separator_Semicolon {
+		*p = pCopy
 		return &Statement{Type: Statement_DeclareLocalVariable, Value: DeclareLocalVariableStatement{Variable: variableName, VariableType: *variableType, Initializer: nil}}, nil
 	}
 
diff --git a/validator.go b/validator.go
index 1c60c17..c39a366 100644
--- a/validator.go
+++ b/validator.go
@@ -2,7 +2,6 @@ package main
 
 import (
 	"errors"
-	"log"
 )
 
 func createError(message string) error {
@@ -160,16 +159,20 @@ func validateExpression(expr *Expression, block *Block) []error {
 	return errors
 }
 
-func validateStatement(stmt *Statement, block *Block) []error {
+func validateStatement(stmt *Statement, block *Block, functionLocals *[]Local) []error {
 	var errors []error
 
+	// TODO: support references to variables in parent block
+
 	switch stmt.Type {
 	case Statement_Expression:
 		expression := stmt.Value.(ExpressionStatement)
 		errors = append(errors, validateExpression(&expression.Expression, block)...)
+		*stmt = Statement{Type: Statement_Expression, Value: expression}
 	case Statement_Block:
 		block := stmt.Value.(BlockStatement)
-		errors = append(errors, validateBlock(&block.Block)...)
+		errors = append(errors, validateBlock(&block.Block, functionLocals)...)
+		*stmt = Statement{Type: Statement_Block, Value: block}
 	case Statement_Return:
 		ret := stmt.Value.(ReturnStatement)
 		if ret.Value != nil {
@@ -185,13 +188,15 @@ func validateStatement(stmt *Statement, block *Block) []error {
 			errors = append(errors, createError("redeclaration of variable '"+dlv.Variable+"'"))
 		}
 
-		block.Locals[dlv.Variable] = Local{Name: dlv.Variable, Type: dlv.VariableType}
+		local := Local{Name: dlv.Variable, Type: dlv.VariableType, IsParameter: false, Index: len(*functionLocals)}
+		block.Locals[dlv.Variable] = local
+		*functionLocals = append(*functionLocals, local)
 	}
 
 	return errors
 }
 
-func validateBlock(block *Block) []error {
+func validateBlock(block *Block, functionLocals *[]Local) []error {
 	var errors []error
 
 	if block.Locals == nil {
@@ -200,7 +205,7 @@ func validateBlock(block *Block) []error {
 
 	for i := range block.Statements {
 		stmt := &block.Statements[i]
-		errors = append(errors, validateStatement(stmt, block)...)
+		errors = append(errors, validateStatement(stmt, block, functionLocals)...)
 	}
 
 	return errors
@@ -209,15 +214,19 @@ func validateBlock(block *Block) []error {
 func validateFunction(function *ParsedFunction) []error {
 	var errors []error
 
+	var locals []Local
+
 	body := &function.Body
 	body.Locals = make(map[string]Local)
 	for _, param := range function.Parameters {
-		body.Locals[param.Name] = Local(param)
+		local := Local{Name: param.Name, Type: param.Type, IsParameter: true, Index: len(locals)}
+		locals = append(locals, local)
+		body.Locals[param.Name] = local
 	}
 
-	errors = append(errors, validateBlock(body)...)
+	errors = append(errors, validateBlock(body, &locals)...)
 
-	log.Printf("%+#v", body)
+	function.Locals = locals
 
 	return errors
 }