2024-03-18 21:14:28 +01:00
|
|
|
package main
|
|
|
|
|
|
|
|
import (
	"errors"
	"sort"
	"strconv"
)
|
|
|
|
|
2024-03-19 08:00:49 +01:00
|
|
|
func getPrimitiveWATType(primitive PrimitiveType) string {
|
2024-03-18 21:14:28 +01:00
|
|
|
switch primitive {
|
|
|
|
case Primitive_I8, Primitive_I16, Primitive_I32, Primitive_U8, Primitive_U16, Primitive_U32:
|
|
|
|
return "i32"
|
|
|
|
case Primitive_I64, Primitive_U64:
|
|
|
|
return "i64"
|
|
|
|
case Primitive_F32:
|
|
|
|
return "f32"
|
|
|
|
case Primitive_F64:
|
|
|
|
return "f64"
|
|
|
|
case Primitive_Bool:
|
|
|
|
return "i32"
|
|
|
|
}
|
|
|
|
|
|
|
|
panic("unhandled type")
|
|
|
|
}
|
|
|
|
|
2024-03-19 08:00:49 +01:00
|
|
|
func getWATType(t Type) string {
|
|
|
|
// TODO: tuples?
|
|
|
|
|
|
|
|
if t.Type != Type_Primitive {
|
|
|
|
panic("not implemented") // TODO: non-primitive types
|
|
|
|
}
|
|
|
|
|
|
|
|
primitive := t.Value.(PrimitiveType)
|
|
|
|
return getPrimitiveWATType(primitive)
|
|
|
|
}
|
|
|
|
|
2024-03-18 21:14:28 +01:00
|
|
|
func getTypeCast(primitive PrimitiveType) string {
|
|
|
|
switch primitive {
|
|
|
|
case Primitive_I8:
|
|
|
|
return "i32.extend8_s\n"
|
|
|
|
case Primitive_U8:
|
|
|
|
return "i32.const 255\ni32.and\n"
|
|
|
|
case Primitive_I16:
|
|
|
|
return "i32.extend16_s\n"
|
|
|
|
case Primitive_U16:
|
|
|
|
return "i32.const 65535\ni32.and\n"
|
|
|
|
case Primitive_Bool:
|
|
|
|
return "i32.const 1\ni32.and\n"
|
|
|
|
}
|
|
|
|
|
|
|
|
return ""
|
|
|
|
}
|
|
|
|
|
|
|
|
func pushConstantNumberWAT(primitive PrimitiveType, value any) string {
|
|
|
|
switch primitive {
|
|
|
|
case Primitive_I8, Primitive_I16, Primitive_I32:
|
|
|
|
return "i32.const " + strconv.FormatInt(value.(int64), 10) + "\n"
|
|
|
|
case Primitive_U8, Primitive_U16, Primitive_U32:
|
|
|
|
return "i32.const " + strconv.FormatUint(value.(uint64), 10) + "\n"
|
|
|
|
case Primitive_I64:
|
|
|
|
return "i64.const " + strconv.FormatInt(value.(int64), 10) + "\n"
|
|
|
|
case Primitive_U64:
|
2024-03-19 12:19:19 +01:00
|
|
|
return "i64.const " + strconv.FormatUint(value.(uint64), 10) + "\n"
|
2024-03-18 21:14:28 +01:00
|
|
|
case Primitive_F32:
|
|
|
|
return "f32.const " + strconv.FormatFloat(value.(float64), 'f', -1, 32) + "\n"
|
|
|
|
case Primitive_F64:
|
|
|
|
return "f64.const " + strconv.FormatFloat(value.(float64), 'f', -1, 64) + "\n"
|
|
|
|
}
|
|
|
|
|
|
|
|
panic("invalid type")
|
|
|
|
}
|
|
|
|
|
2024-03-19 12:19:19 +01:00
|
|
|
func castPrimitiveWAT(from PrimitiveType, to PrimitiveType) (string, error) {
|
2024-03-18 21:14:28 +01:00
|
|
|
if from == to {
|
|
|
|
return "", nil
|
|
|
|
}
|
|
|
|
|
|
|
|
if from == Primitive_Bool || to == Primitive_Bool {
|
|
|
|
return "", errors.New("cannot upcast from or to bool")
|
|
|
|
}
|
|
|
|
|
2024-03-19 08:00:49 +01:00
|
|
|
fromFloat := isFloatingPoint(from)
|
|
|
|
toFloat := isFloatingPoint(to)
|
|
|
|
if fromFloat && toFloat {
|
|
|
|
if to == Primitive_F32 {
|
|
|
|
return "f32.demote_f64\n", nil
|
|
|
|
} else {
|
|
|
|
return "f64.promote_f32\n", nil
|
|
|
|
}
|
2024-03-18 21:14:28 +01:00
|
|
|
}
|
|
|
|
|
2024-03-19 08:00:49 +01:00
|
|
|
if toFloat {
|
|
|
|
suffix := ""
|
|
|
|
if isUnsignedInt(to) {
|
|
|
|
suffix = "u"
|
|
|
|
} else {
|
|
|
|
suffix = "s"
|
|
|
|
}
|
|
|
|
|
|
|
|
return getPrimitiveWATType(to) + ".convert_" + getPrimitiveWATType(from) + "_" + suffix + "\n", nil
|
2024-03-18 21:14:28 +01:00
|
|
|
}
|
|
|
|
|
2024-03-19 08:00:49 +01:00
|
|
|
if fromFloat {
|
|
|
|
suffix := ""
|
|
|
|
|
|
|
|
if isUnsignedInt(to) {
|
|
|
|
suffix = "u"
|
|
|
|
} else {
|
|
|
|
suffix = "s"
|
|
|
|
}
|
|
|
|
|
|
|
|
return getPrimitiveWATType(to) + ".trunc_" + getPrimitiveWATType(from) + "_" + suffix + "\n", nil
|
2024-03-18 21:14:28 +01:00
|
|
|
}
|
|
|
|
|
2024-03-19 08:00:49 +01:00
|
|
|
if getBits(from) == getBits(to) {
|
|
|
|
return "", nil
|
|
|
|
}
|
2024-03-18 21:14:28 +01:00
|
|
|
|
2024-03-19 08:00:49 +01:00
|
|
|
if getPrimitiveWATType(from) == getPrimitiveWATType(to) {
|
2024-03-19 10:54:21 +01:00
|
|
|
if getBits(to) < getBits(from) {
|
|
|
|
return getTypeCast(to), nil
|
|
|
|
}
|
2024-03-18 21:14:28 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
if getBits(from) < 64 && getBits(to) == 64 {
|
2024-03-19 08:00:49 +01:00
|
|
|
suffix := ""
|
|
|
|
if isUnsignedInt(from) {
|
|
|
|
suffix = "u"
|
2024-03-18 21:14:28 +01:00
|
|
|
} else {
|
2024-03-19 08:00:49 +01:00
|
|
|
suffix = "s"
|
2024-03-18 21:14:28 +01:00
|
|
|
}
|
|
|
|
|
2024-03-19 08:00:49 +01:00
|
|
|
return "i64.extend_i32_" + suffix + "\n", nil
|
2024-03-18 21:14:28 +01:00
|
|
|
}
|
|
|
|
|
2024-03-19 10:54:21 +01:00
|
|
|
return "i32.wrap_i64\n" + getTypeCast(to), nil
|
2024-03-18 21:14:28 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
func compileExpressionWAT(expr Expression, block Block) (string, error) {
|
|
|
|
switch expr.Type {
|
|
|
|
case Expression_Assignment:
|
|
|
|
|
|
|
|
case Expression_Literal:
|
|
|
|
lit := expr.Value.(LiteralExpression)
|
|
|
|
switch lit.Literal.Type {
|
|
|
|
case Literal_Number:
|
|
|
|
return pushConstantNumberWAT(lit.Literal.Primitive, lit.Literal.Value), nil
|
|
|
|
case Literal_Boolean:
|
|
|
|
if lit.Literal.Value.(bool) {
|
|
|
|
return "i32.const 1\n", nil
|
|
|
|
} else {
|
|
|
|
return "i32.const 0\n", nil
|
|
|
|
}
|
|
|
|
case Literal_String:
|
|
|
|
panic("not implemented")
|
|
|
|
}
|
|
|
|
case Expression_VariableReference:
|
|
|
|
ref := expr.Value.(VariableReferenceExpression)
|
2024-03-19 12:19:19 +01:00
|
|
|
|
|
|
|
cast := ""
|
|
|
|
if expr.ValueType.Type == Type_Primitive {
|
|
|
|
cast = getTypeCast(expr.ValueType.Value.(PrimitiveType))
|
|
|
|
}
|
|
|
|
|
|
|
|
return "local.get $" + strconv.Itoa(block.Locals[ref.Variable].Index) + "\n" + cast, nil
|
2024-03-18 21:14:28 +01:00
|
|
|
case Expression_Arithmetic:
|
|
|
|
arith := expr.Value.(ArithmeticExpression)
|
|
|
|
|
2024-03-19 10:54:21 +01:00
|
|
|
// TODO: currently fine, only allowed for primitive types, but should be expanded to allow e.g. strings
|
|
|
|
exprType := expr.ValueType.Value.(PrimitiveType)
|
|
|
|
|
2024-03-18 21:14:28 +01:00
|
|
|
watLeft, err := compileExpressionWAT(arith.Left, block)
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
|
2024-03-19 12:19:19 +01:00
|
|
|
castLeft, err := castPrimitiveWAT(arith.Left.ValueType.Value.(PrimitiveType), exprType)
|
2024-03-19 10:54:21 +01:00
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
|
2024-03-18 21:14:28 +01:00
|
|
|
watRight, err := compileExpressionWAT(arith.Right, block)
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
|
2024-03-19 12:19:19 +01:00
|
|
|
castRight, err := castPrimitiveWAT(arith.Right.ValueType.Value.(PrimitiveType), exprType)
|
2024-03-19 10:54:21 +01:00
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
|
|
|
|
op := ""
|
2024-03-18 21:14:28 +01:00
|
|
|
|
2024-03-19 10:54:21 +01:00
|
|
|
suffix := ""
|
|
|
|
if isUnsignedInt(exprType) {
|
|
|
|
suffix = "u"
|
|
|
|
} else {
|
|
|
|
suffix = "s"
|
|
|
|
}
|
|
|
|
|
|
|
|
switch arith.Operation {
|
|
|
|
case Arithmetic_Add:
|
|
|
|
op = getPrimitiveWATType(exprType) + ".add\n"
|
|
|
|
case Arithmetic_Sub:
|
|
|
|
op = getPrimitiveWATType(exprType) + ".sub\n"
|
|
|
|
case Arithmetic_Mul:
|
|
|
|
op = getPrimitiveWATType(exprType) + ".mul\n"
|
|
|
|
case Arithmetic_Div:
|
2024-03-19 12:19:19 +01:00
|
|
|
op = getPrimitiveWATType(exprType) + ".div_" + suffix + "\n"
|
2024-03-19 10:54:21 +01:00
|
|
|
case Arithmetic_Mod:
|
2024-03-19 12:19:19 +01:00
|
|
|
op = getPrimitiveWATType(exprType) + ".rem_" + suffix + "\n"
|
2024-03-19 10:54:21 +01:00
|
|
|
}
|
|
|
|
|
2024-03-19 12:19:19 +01:00
|
|
|
return watLeft + castLeft + watRight + castRight + op + getTypeCast(exprType), nil
|
2024-03-18 21:14:28 +01:00
|
|
|
case Expression_Tuple:
|
|
|
|
}
|
|
|
|
|
|
|
|
return "", nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func compileStatementWAT(stmt Statement, block Block) (string, error) {
|
|
|
|
switch stmt.Type {
|
|
|
|
case Statement_Expression:
|
|
|
|
expr := stmt.Value.(ExpressionStatement)
|
|
|
|
wat, err := compileExpressionWAT(expr.Expression, block)
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
|
|
|
|
return wat + "drop\n", nil
|
|
|
|
case Statement_Block:
|
|
|
|
block := stmt.Value.(BlockStatement)
|
|
|
|
wat, err := compileBlockWAT(block.Block)
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
|
|
|
|
return wat, nil
|
|
|
|
case Statement_Return:
|
|
|
|
ret := stmt.Value.(ReturnStatement)
|
|
|
|
wat, err := compileExpressionWAT(*ret.Value, block)
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
|
|
|
|
return wat + "return\n", nil
|
|
|
|
case Statement_DeclareLocalVariable:
|
|
|
|
dlv := stmt.Value.(DeclareLocalVariableStatement)
|
|
|
|
if dlv.Initializer == nil {
|
|
|
|
return "", nil
|
|
|
|
}
|
|
|
|
|
|
|
|
wat, err := compileExpressionWAT(*dlv.Initializer, block)
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
|
|
|
|
return wat + "local.set $" + strconv.Itoa(block.Locals[dlv.Variable].Index) + "\n", nil
|
|
|
|
}
|
|
|
|
|
|
|
|
return "", nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func compileBlockWAT(block Block) (string, error) {
|
|
|
|
blockWAT := ""
|
|
|
|
|
|
|
|
for _, stmt := range block.Statements {
|
|
|
|
wat, err := compileStatementWAT(stmt, block)
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
|
|
|
|
blockWAT += wat
|
|
|
|
}
|
|
|
|
|
|
|
|
return blockWAT, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func compileFunctionWAT(function ParsedFunction) (string, error) {
|
|
|
|
funcWAT := "(func $" + function.Name + "\n"
|
|
|
|
|
|
|
|
for _, local := range function.Locals {
|
2024-03-19 12:48:06 +01:00
|
|
|
if !local.IsParameter {
|
|
|
|
continue
|
2024-03-18 21:14:28 +01:00
|
|
|
}
|
2024-03-19 12:48:06 +01:00
|
|
|
funcWAT += "\t(param $" + strconv.Itoa(local.Index) + " " + getWATType(local.Type) + ")\n"
|
2024-03-18 21:14:28 +01:00
|
|
|
}
|
|
|
|
|
2024-03-19 10:54:21 +01:00
|
|
|
// TODO: tuples
|
|
|
|
funcWAT += "\t(result " + getWATType(function.ReturnType) + ")\n"
|
|
|
|
|
2024-03-19 12:48:06 +01:00
|
|
|
for _, local := range function.Locals {
|
|
|
|
if local.IsParameter {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
funcWAT += "\t(local $" + strconv.Itoa(local.Index) + " " + getWATType(local.Type) + ")\n"
|
|
|
|
}
|
|
|
|
|
2024-03-18 21:14:28 +01:00
|
|
|
wat, err := compileBlockWAT(function.Body)
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
|
|
|
|
funcWAT += wat
|
|
|
|
|
|
|
|
return funcWAT + ") (export \"" + function.Name + "\" (func $" + function.Name + "))", nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func backendWAT(file ParsedFile) (string, error) {
|
|
|
|
module := "(module (memory 1)\n"
|
|
|
|
|
|
|
|
for _, function := range file.Functions {
|
|
|
|
wat, err := compileFunctionWAT(function)
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
|
|
|
|
module += wat
|
|
|
|
}
|
|
|
|
|
|
|
|
module += ")"
|
|
|
|
return module, nil
|
|
|
|
}
|