start WAT compiler backend
This commit is contained in:
parent
ad32195fa2
commit
a7ec08b379
1
.gitignore
vendored
1
.gitignore
vendored
@ -1 +1,2 @@
|
||||
compiler
|
||||
out.wat
|
||||
|
252
backend_wat.go
Normal file
252
backend_wat.go
Normal file
@ -0,0 +1,252 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func getWATType(t Type) string {
|
||||
// TODO: tuples?
|
||||
|
||||
if t.Type != Type_Primitive {
|
||||
panic("not implemented") // TODO: non-primitive types
|
||||
}
|
||||
|
||||
primitive := t.Value.(PrimitiveType)
|
||||
switch primitive {
|
||||
case Primitive_I8, Primitive_I16, Primitive_I32, Primitive_U8, Primitive_U16, Primitive_U32:
|
||||
return "i32"
|
||||
case Primitive_I64, Primitive_U64:
|
||||
return "i64"
|
||||
case Primitive_F32:
|
||||
return "f32"
|
||||
case Primitive_F64:
|
||||
return "f64"
|
||||
case Primitive_Bool:
|
||||
return "i32"
|
||||
}
|
||||
|
||||
panic("unhandled type")
|
||||
}
|
||||
|
||||
func getTypeCast(primitive PrimitiveType) string {
|
||||
switch primitive {
|
||||
case Primitive_I8:
|
||||
return "i32.extend8_s\n"
|
||||
case Primitive_U8:
|
||||
return "i32.const 255\ni32.and\n"
|
||||
case Primitive_I16:
|
||||
return "i32.extend16_s\n"
|
||||
case Primitive_U16:
|
||||
return "i32.const 65535\ni32.and\n"
|
||||
case Primitive_Bool:
|
||||
return "i32.const 1\ni32.and\n"
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func pushConstantNumberWAT(primitive PrimitiveType, value any) string {
|
||||
switch primitive {
|
||||
case Primitive_I8, Primitive_I16, Primitive_I32:
|
||||
return "i32.const " + strconv.FormatInt(value.(int64), 10) + "\n"
|
||||
case Primitive_U8, Primitive_U16, Primitive_U32:
|
||||
return "i32.const " + strconv.FormatUint(value.(uint64), 10) + "\n"
|
||||
case Primitive_I64:
|
||||
return "i64.const " + strconv.FormatInt(value.(int64), 10) + "\n"
|
||||
case Primitive_U64:
|
||||
return "u64.const " + strconv.FormatUint(value.(uint64), 10) + "\n"
|
||||
case Primitive_F32:
|
||||
return "f32.const " + strconv.FormatFloat(value.(float64), 'f', -1, 32) + "\n"
|
||||
case Primitive_F64:
|
||||
return "f64.const " + strconv.FormatFloat(value.(float64), 'f', -1, 64) + "\n"
|
||||
}
|
||||
|
||||
panic("invalid type")
|
||||
}
|
||||
|
||||
func upcastTypeWAT(from PrimitiveType, to PrimitiveType) (string, error) {
|
||||
// TODO: refactor
|
||||
|
||||
if from == to {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if from == Primitive_Bool || to == Primitive_Bool {
|
||||
return "", errors.New("cannot upcast from or to bool")
|
||||
}
|
||||
|
||||
if from == Primitive_F32 && to == Primitive_F64 {
|
||||
return "f64.promote_f32\n", nil
|
||||
}
|
||||
|
||||
if from == Primitive_F64 && to == Primitive_F32 {
|
||||
return "f32.demote_f64\n", nil
|
||||
}
|
||||
|
||||
if isFloatingPoint(from) || isFloatingPoint(to) {
|
||||
return "", errors.New("cannot upcast int from/to float")
|
||||
}
|
||||
|
||||
wat := ""
|
||||
|
||||
if getBits(from) == 64 && getBits(to) < 64 {
|
||||
wat += "i32.wrap_i64\n"
|
||||
}
|
||||
|
||||
if getBits(from) < 64 && getBits(to) == 64 {
|
||||
if to == Primitive_I64 {
|
||||
wat += "i64.extend_i32_s\n"
|
||||
} else {
|
||||
wat += "i64.extend_i32_u\n"
|
||||
}
|
||||
}
|
||||
|
||||
switch to {
|
||||
case Primitive_I8, Primitive_I16, Primitive_I32, Primitive_U8, Primitive_U16, Primitive_U32:
|
||||
wat += getTypeCast(to)
|
||||
}
|
||||
|
||||
return wat, nil
|
||||
}
|
||||
|
||||
func compileExpressionWAT(expr Expression, block Block) (string, error) {
|
||||
switch expr.Type {
|
||||
case Expression_Assignment:
|
||||
|
||||
case Expression_Literal:
|
||||
lit := expr.Value.(LiteralExpression)
|
||||
switch lit.Literal.Type {
|
||||
case Literal_Number:
|
||||
return pushConstantNumberWAT(lit.Literal.Primitive, lit.Literal.Value), nil
|
||||
case Literal_Boolean:
|
||||
if lit.Literal.Value.(bool) {
|
||||
return "i32.const 1\n", nil
|
||||
} else {
|
||||
return "i32.const 0\n", nil
|
||||
}
|
||||
case Literal_String:
|
||||
panic("not implemented")
|
||||
}
|
||||
case Expression_VariableReference:
|
||||
ref := expr.Value.(VariableReferenceExpression)
|
||||
return "local.get $" + strconv.Itoa(block.Locals[ref.Variable].Index) + "\n", nil
|
||||
case Expression_Arithmetic:
|
||||
arith := expr.Value.(ArithmeticExpression)
|
||||
|
||||
watLeft, err := compileExpressionWAT(arith.Left, block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
watRight, err := compileExpressionWAT(arith.Right, block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
//TODO: upcast expressions and perform operation
|
||||
|
||||
// TODO: cast result
|
||||
return watLeft + watRight, nil
|
||||
case Expression_Tuple:
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func compileStatementWAT(stmt Statement, block Block) (string, error) {
|
||||
switch stmt.Type {
|
||||
case Statement_Expression:
|
||||
expr := stmt.Value.(ExpressionStatement)
|
||||
wat, err := compileExpressionWAT(expr.Expression, block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return wat + "drop\n", nil
|
||||
case Statement_Block:
|
||||
block := stmt.Value.(BlockStatement)
|
||||
wat, err := compileBlockWAT(block.Block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return wat, nil
|
||||
case Statement_Return:
|
||||
ret := stmt.Value.(ReturnStatement)
|
||||
wat, err := compileExpressionWAT(*ret.Value, block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return wat + "return\n", nil
|
||||
case Statement_DeclareLocalVariable:
|
||||
dlv := stmt.Value.(DeclareLocalVariableStatement)
|
||||
if dlv.Initializer == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
wat, err := compileExpressionWAT(*dlv.Initializer, block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return wat + "local.set $" + strconv.Itoa(block.Locals[dlv.Variable].Index) + "\n", nil
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func compileBlockWAT(block Block) (string, error) {
|
||||
blockWAT := ""
|
||||
|
||||
for _, stmt := range block.Statements {
|
||||
wat, err := compileStatementWAT(stmt, block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
blockWAT += wat
|
||||
}
|
||||
|
||||
return blockWAT, nil
|
||||
}
|
||||
|
||||
func compileFunctionWAT(function ParsedFunction) (string, error) {
|
||||
funcWAT := "(func $" + function.Name + "\n"
|
||||
|
||||
for _, local := range function.Locals {
|
||||
pfx := ""
|
||||
if local.IsParameter {
|
||||
pfx = "param"
|
||||
} else {
|
||||
pfx = "local"
|
||||
}
|
||||
funcWAT += "\t(" + pfx + " $" + strconv.Itoa(local.Index) + " " + getWATType(local.Type) + ")\n"
|
||||
}
|
||||
|
||||
wat, err := compileBlockWAT(function.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
funcWAT += wat
|
||||
|
||||
return funcWAT + ") (export \"" + function.Name + "\" (func $" + function.Name + "))", nil
|
||||
}
|
||||
|
||||
func backendWAT(file ParsedFile) (string, error) {
|
||||
module := "(module (memory 1)\n"
|
||||
|
||||
for _, function := range file.Functions {
|
||||
wat, err := compileFunctionWAT(function)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
module += wat
|
||||
}
|
||||
|
||||
module += ")"
|
||||
return module, nil
|
||||
}
|
12
main.go
12
main.go
@ -51,4 +51,16 @@ func main() {
|
||||
}
|
||||
|
||||
log.Printf("Validated:\n%+#v\n\n", parsed)
|
||||
|
||||
wat, err := backendWAT(*parsed)
|
||||
if err != nil {
|
||||
if c, ok := err.(CompilerError); ok {
|
||||
log.Fatalln(err, "\nv- here\n"+string([]rune(string(content))[c.Position:]))
|
||||
}
|
||||
|
||||
log.Fatalln(err)
|
||||
}
|
||||
|
||||
log.Println("WAT: " + wat)
|
||||
os.WriteFile("out.wat", []byte(wat), 0o644)
|
||||
}
|
||||
|
@ -114,11 +114,14 @@ type TupleExpression struct {
|
||||
}
|
||||
|
||||
type Local struct {
|
||||
Name string
|
||||
Type Type
|
||||
Name string
|
||||
Type Type
|
||||
IsParameter bool
|
||||
Index int // unique, 0-based index of the local in the current function
|
||||
}
|
||||
|
||||
type Block struct {
|
||||
Parent *Block // TODO: implement
|
||||
Statements []Statement
|
||||
Locals map[string]Local
|
||||
}
|
||||
@ -133,6 +136,7 @@ type ParsedFunction struct {
|
||||
Parameters []ParsedParameter
|
||||
ReturnType Type
|
||||
Body Block
|
||||
Locals []Local // All of the locals of the function, ordered by their index
|
||||
}
|
||||
|
||||
type Import struct {
|
||||
@ -576,6 +580,7 @@ func (p *Parser) tryDeclareLocalVariableStatement() (*Statement, error) {
|
||||
|
||||
token := pCopy.nextToken()
|
||||
if token.Type == Type_Separator && token.Value.(Separator) == Separator_Semicolon {
|
||||
*p = pCopy
|
||||
return &Statement{Type: Statement_DeclareLocalVariable, Value: DeclareLocalVariableStatement{Variable: variableName, VariableType: *variableType, Initializer: nil}}, nil
|
||||
}
|
||||
|
||||
|
27
validator.go
27
validator.go
@ -2,7 +2,6 @@ package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
)
|
||||
|
||||
func createError(message string) error {
|
||||
@ -160,16 +159,20 @@ func validateExpression(expr *Expression, block *Block) []error {
|
||||
return errors
|
||||
}
|
||||
|
||||
func validateStatement(stmt *Statement, block *Block) []error {
|
||||
func validateStatement(stmt *Statement, block *Block, functionLocals *[]Local) []error {
|
||||
var errors []error
|
||||
|
||||
// TODO: support references to variables in parent block
|
||||
|
||||
switch stmt.Type {
|
||||
case Statement_Expression:
|
||||
expression := stmt.Value.(ExpressionStatement)
|
||||
errors = append(errors, validateExpression(&expression.Expression, block)...)
|
||||
*stmt = Statement{Type: Statement_Expression, Value: expression}
|
||||
case Statement_Block:
|
||||
block := stmt.Value.(BlockStatement)
|
||||
errors = append(errors, validateBlock(&block.Block)...)
|
||||
errors = append(errors, validateBlock(&block.Block, functionLocals)...)
|
||||
*stmt = Statement{Type: Statement_Block, Value: block}
|
||||
case Statement_Return:
|
||||
ret := stmt.Value.(ReturnStatement)
|
||||
if ret.Value != nil {
|
||||
@ -185,13 +188,15 @@ func validateStatement(stmt *Statement, block *Block) []error {
|
||||
errors = append(errors, createError("redeclaration of variable '"+dlv.Variable+"'"))
|
||||
}
|
||||
|
||||
block.Locals[dlv.Variable] = Local{Name: dlv.Variable, Type: dlv.VariableType}
|
||||
local := Local{Name: dlv.Variable, Type: dlv.VariableType, IsParameter: false, Index: len(*functionLocals)}
|
||||
block.Locals[dlv.Variable] = local
|
||||
*functionLocals = append(*functionLocals, local)
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
func validateBlock(block *Block) []error {
|
||||
func validateBlock(block *Block, functionLocals *[]Local) []error {
|
||||
var errors []error
|
||||
|
||||
if block.Locals == nil {
|
||||
@ -200,7 +205,7 @@ func validateBlock(block *Block) []error {
|
||||
|
||||
for i := range block.Statements {
|
||||
stmt := &block.Statements[i]
|
||||
errors = append(errors, validateStatement(stmt, block)...)
|
||||
errors = append(errors, validateStatement(stmt, block, functionLocals)...)
|
||||
}
|
||||
|
||||
return errors
|
||||
@ -209,15 +214,19 @@ func validateBlock(block *Block) []error {
|
||||
func validateFunction(function *ParsedFunction) []error {
|
||||
var errors []error
|
||||
|
||||
var locals []Local
|
||||
|
||||
body := &function.Body
|
||||
body.Locals = make(map[string]Local)
|
||||
for _, param := range function.Parameters {
|
||||
body.Locals[param.Name] = Local(param)
|
||||
local := Local{Name: param.Name, Type: param.Type, IsParameter: true, Index: len(locals)}
|
||||
locals = append(locals, local)
|
||||
body.Locals[param.Name] = local
|
||||
}
|
||||
|
||||
errors = append(errors, validateBlock(body)...)
|
||||
errors = append(errors, validateBlock(body, &locals)...)
|
||||
|
||||
log.Printf("%+#v", body)
|
||||
function.Locals = locals
|
||||
|
||||
return errors
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user