package main import ( "errors" "strconv" ) func getWATType(t Type) string { // TODO: tuples? if t.Type != Type_Primitive { panic("not implemented") // TODO: non-primitive types } primitive := t.Value.(PrimitiveType) switch primitive { case Primitive_I8, Primitive_I16, Primitive_I32, Primitive_U8, Primitive_U16, Primitive_U32: return "i32" case Primitive_I64, Primitive_U64: return "i64" case Primitive_F32: return "f32" case Primitive_F64: return "f64" case Primitive_Bool: return "i32" } panic("unhandled type") } func getTypeCast(primitive PrimitiveType) string { switch primitive { case Primitive_I8: return "i32.extend8_s\n" case Primitive_U8: return "i32.const 255\ni32.and\n" case Primitive_I16: return "i32.extend16_s\n" case Primitive_U16: return "i32.const 65535\ni32.and\n" case Primitive_Bool: return "i32.const 1\ni32.and\n" } return "" } func pushConstantNumberWAT(primitive PrimitiveType, value any) string { switch primitive { case Primitive_I8, Primitive_I16, Primitive_I32: return "i32.const " + strconv.FormatInt(value.(int64), 10) + "\n" case Primitive_U8, Primitive_U16, Primitive_U32: return "i32.const " + strconv.FormatUint(value.(uint64), 10) + "\n" case Primitive_I64: return "i64.const " + strconv.FormatInt(value.(int64), 10) + "\n" case Primitive_U64: return "u64.const " + strconv.FormatUint(value.(uint64), 10) + "\n" case Primitive_F32: return "f32.const " + strconv.FormatFloat(value.(float64), 'f', -1, 32) + "\n" case Primitive_F64: return "f64.const " + strconv.FormatFloat(value.(float64), 'f', -1, 64) + "\n" } panic("invalid type") } func upcastTypeWAT(from PrimitiveType, to PrimitiveType) (string, error) { // TODO: refactor if from == to { return "", nil } if from == Primitive_Bool || to == Primitive_Bool { return "", errors.New("cannot upcast from or to bool") } if from == Primitive_F32 && to == Primitive_F64 { return "f64.promote_f32\n", nil } if from == Primitive_F64 && to == Primitive_F32 { return "f32.demote_f64\n", nil } if isFloatingPoint(from) || isFloatingPoint(to) { return "", errors.New("cannot upcast int from/to float") } wat := "" if getBits(from) == 64 && getBits(to) < 64 { wat += "i32.wrap_i64\n" } if getBits(from) < 64 && getBits(to) == 64 { if to == Primitive_I64 { wat += "i64.extend_i32_s\n" } else { wat += "i64.extend_i32_u\n" } } switch to { case Primitive_I8, Primitive_I16, Primitive_I32, Primitive_U8, Primitive_U16, Primitive_U32: wat += getTypeCast(to) } return wat, nil } func compileExpressionWAT(expr Expression, block Block) (string, error) { switch expr.Type { case Expression_Assignment: case Expression_Literal: lit := expr.Value.(LiteralExpression) switch lit.Literal.Type { case Literal_Number: return pushConstantNumberWAT(lit.Literal.Primitive, lit.Literal.Value), nil case Literal_Boolean: if lit.Literal.Value.(bool) { return "i32.const 1\n", nil } else { return "i32.const 0\n", nil } case Literal_String: panic("not implemented") } case Expression_VariableReference: ref := expr.Value.(VariableReferenceExpression) return "local.get $" + strconv.Itoa(block.Locals[ref.Variable].Index) + "\n", nil case Expression_Arithmetic: arith := expr.Value.(ArithmeticExpression) watLeft, err := compileExpressionWAT(arith.Left, block) if err != nil { return "", err } watRight, err := compileExpressionWAT(arith.Right, block) if err != nil { return "", err } //TODO: upcast expressions and perform operation // TODO: cast result return watLeft + watRight, nil case Expression_Tuple: } return "", nil } func compileStatementWAT(stmt Statement, block Block) (string, error) { switch stmt.Type { case Statement_Expression: expr := stmt.Value.(ExpressionStatement) wat, err := compileExpressionWAT(expr.Expression, block) if err != nil { return "", err } return wat + "drop\n", nil case Statement_Block: block := stmt.Value.(BlockStatement) wat, err := compileBlockWAT(block.Block) if err != nil { return "", err } return wat, nil case Statement_Return: ret := stmt.Value.(ReturnStatement) wat, err := compileExpressionWAT(*ret.Value, block) if err != nil { return "", err } return wat + "return\n", nil case Statement_DeclareLocalVariable: dlv := stmt.Value.(DeclareLocalVariableStatement) if dlv.Initializer == nil { return "", nil } wat, err := compileExpressionWAT(*dlv.Initializer, block) if err != nil { return "", err } return wat + "local.set $" + strconv.Itoa(block.Locals[dlv.Variable].Index) + "\n", nil } return "", nil } func compileBlockWAT(block Block) (string, error) { blockWAT := "" for _, stmt := range block.Statements { wat, err := compileStatementWAT(stmt, block) if err != nil { return "", err } blockWAT += wat } return blockWAT, nil } func compileFunctionWAT(function ParsedFunction) (string, error) { funcWAT := "(func $" + function.Name + "\n" for _, local := range function.Locals { pfx := "" if local.IsParameter { pfx = "param" } else { pfx = "local" } funcWAT += "\t(" + pfx + " $" + strconv.Itoa(local.Index) + " " + getWATType(local.Type) + ")\n" } wat, err := compileBlockWAT(function.Body) if err != nil { return "", err } funcWAT += wat return funcWAT + ") (export \"" + function.Name + "\" (func $" + function.Name + "))", nil } func backendWAT(file ParsedFile) (string, error) { module := "(module (memory 1)\n" for _, function := range file.Functions { wat, err := compileFunctionWAT(function) if err != nil { return "", err } module += wat } module += ")" return module, nil }