elysium/backend_wat.go

280 lines
6.0 KiB
Go
Raw Normal View History

2024-03-18 21:14:28 +01:00
package main
import (
"errors"
"strconv"
)
2024-03-19 08:00:49 +01:00
func getPrimitiveWATType(primitive PrimitiveType) string {
2024-03-18 21:14:28 +01:00
switch primitive {
case Primitive_I8, Primitive_I16, Primitive_I32, Primitive_U8, Primitive_U16, Primitive_U32:
return "i32"
case Primitive_I64, Primitive_U64:
return "i64"
case Primitive_F32:
return "f32"
case Primitive_F64:
return "f64"
case Primitive_Bool:
return "i32"
}
panic("unhandled type")
}
2024-03-19 08:00:49 +01:00
// getWATType returns the WebAssembly value type for t. Only primitive types
// are supported so far; they are delegated to getPrimitiveWATType. Any other
// kind of type panics with "not implemented".
func getWATType(t Type) string {
	if t.Type == Type_Primitive {
		return getPrimitiveWATType(t.Value.(PrimitiveType))
	}
	// TODO: non-primitive types (tuples?)
	panic("not implemented")
}
2024-03-18 21:14:28 +01:00
// getTypeCast returns the WAT instructions that re-normalize an i32 on the
// operand stack to the value range of primitive: i8/i16 are sign-extended
// from their low bits, u8/u16 are masked, and bool keeps only its low bit.
// Every other primitive needs no adjustment and yields "".
func getTypeCast(primitive PrimitiveType) string {
	cast := ""
	switch primitive {
	case Primitive_Bool:
		cast = "i32.const 1\ni32.and\n"
	case Primitive_I8:
		cast = "i32.extend8_s\n"
	case Primitive_I16:
		cast = "i32.extend16_s\n"
	case Primitive_U8:
		cast = "i32.const 255\ni32.and\n"
	case Primitive_U16:
		cast = "i32.const 65535\ni32.and\n"
	}
	return cast
}
// pushConstantNumberWAT returns a single WAT `<type>.const` instruction that
// pushes the numeric literal value as the given primitive type.
//
// value must hold an int64 for signed integer types, a uint64 for unsigned
// integer types and a float64 for floating-point types; any other dynamic
// type panics on the type assertion. Primitives without a numeric constant
// form (e.g. bool) panic with "invalid type".
func pushConstantNumberWAT(primitive PrimitiveType, value any) string {
	// floatLiteral formats value for a WAT float constant. strconv spells
	// non-finite values "+Inf", "-Inf" and "NaN", whereas the WAT text format
	// spells them "inf", "-inf" and "nan".
	floatLiteral := func(bitSize int) string {
		s := strconv.FormatFloat(value.(float64), 'f', -1, bitSize)
		switch s {
		case "+Inf":
			return "inf"
		case "-Inf":
			return "-inf"
		case "NaN":
			return "nan"
		}
		return s
	}

	switch primitive {
	case Primitive_I8, Primitive_I16, Primitive_I32:
		return "i32.const " + strconv.FormatInt(value.(int64), 10) + "\n"
	case Primitive_U8, Primitive_U16, Primitive_U32:
		return "i32.const " + strconv.FormatUint(value.(uint64), 10) + "\n"
	case Primitive_I64:
		return "i64.const " + strconv.FormatInt(value.(int64), 10) + "\n"
	case Primitive_U64:
		// WebAssembly has no unsigned value types: u64 lives in an i64, and
		// i64.const accepts literals across the full unsigned 64-bit range.
		return "i64.const " + strconv.FormatUint(value.(uint64), 10) + "\n"
	case Primitive_F32:
		return "f32.const " + floatLiteral(32) + "\n"
	case Primitive_F64:
		return "f64.const " + floatLiteral(64) + "\n"
	}
	panic("invalid type")
}
func upcastTypeWAT(from PrimitiveType, to PrimitiveType) (string, error) {
// TODO: refactor
if from == to {
return "", nil
}
if from == Primitive_Bool || to == Primitive_Bool {
return "", errors.New("cannot upcast from or to bool")
}
2024-03-19 08:00:49 +01:00
fromFloat := isFloatingPoint(from)
toFloat := isFloatingPoint(to)
if fromFloat && toFloat {
if to == Primitive_F32 {
return "f32.demote_f64\n", nil
} else {
return "f64.promote_f32\n", nil
}
2024-03-18 21:14:28 +01:00
}
2024-03-19 08:00:49 +01:00
if toFloat {
suffix := ""
if isUnsignedInt(to) {
suffix = "u"
} else {
suffix = "s"
}
return getPrimitiveWATType(to) + ".convert_" + getPrimitiveWATType(from) + "_" + suffix + "\n", nil
2024-03-18 21:14:28 +01:00
}
2024-03-19 08:00:49 +01:00
if fromFloat {
suffix := ""
if isUnsignedInt(to) {
suffix = "u"
} else {
suffix = "s"
}
return getPrimitiveWATType(to) + ".trunc_" + getPrimitiveWATType(from) + "_" + suffix + "\n", nil
2024-03-18 21:14:28 +01:00
}
2024-03-19 08:00:49 +01:00
if getBits(from) == getBits(to) {
return "", nil
}
2024-03-18 21:14:28 +01:00
2024-03-19 08:00:49 +01:00
if getPrimitiveWATType(from) == getPrimitiveWATType(to) {
// TODO: cast if to is smaller than from
2024-03-18 21:14:28 +01:00
}
if getBits(from) < 64 && getBits(to) == 64 {
2024-03-19 08:00:49 +01:00
suffix := ""
if isUnsignedInt(from) {
suffix = "u"
2024-03-18 21:14:28 +01:00
} else {
2024-03-19 08:00:49 +01:00
suffix = "s"
2024-03-18 21:14:28 +01:00
}
2024-03-19 08:00:49 +01:00
return "i64.extend_i32_" + suffix + "\n", nil
2024-03-18 21:14:28 +01:00
}
2024-03-19 08:00:49 +01:00
// TODO: cast down from 64 to 32
return "", nil
2024-03-18 21:14:28 +01:00
}
// compileExpressionWAT compiles expr into WAT instructions intended to leave
// the expression's value on the operand stack. Variable references are
// resolved to `local.get $<Index>` through block.Locals.
//
// Expression and literal kinds this backend cannot compile yet are reported
// as errors rather than compiled to no instructions: callers append
// instructions such as `drop`, `return` or `local.set` that expect a value
// on the stack, so an empty result would yield broken WAT.
func compileExpressionWAT(expr Expression, block Block) (string, error) {
	switch expr.Type {
	case Expression_Assignment:
		return "", errors.New("wat backend: assignment expressions are not implemented")
	case Expression_Literal:
		lit := expr.Value.(LiteralExpression)
		switch lit.Literal.Type {
		case Literal_Number:
			return pushConstantNumberWAT(lit.Literal.Primitive, lit.Literal.Value), nil
		case Literal_Boolean:
			if lit.Literal.Value.(bool) {
				return "i32.const 1\n", nil
			}
			return "i32.const 0\n", nil
		case Literal_String:
			return "", errors.New("wat backend: string literals are not implemented")
		}
		return "", errors.New("wat backend: unsupported literal type")
	case Expression_VariableReference:
		ref := expr.Value.(VariableReferenceExpression)
		return "local.get $" + strconv.Itoa(block.Locals[ref.Variable].Index) + "\n", nil
	case Expression_Arithmetic:
		arith := expr.Value.(ArithmeticExpression)
		watLeft, err := compileExpressionWAT(arith.Left, block)
		if err != nil {
			return "", err
		}
		watRight, err := compileExpressionWAT(arith.Right, block)
		if err != nil {
			return "", err
		}
		// TODO: upcast operands, emit the operator instruction and cast the
		// result. Until then both operand values are left on the stack.
		return watLeft + watRight, nil
	case Expression_Tuple:
		return "", errors.New("wat backend: tuple expressions are not implemented")
	}
	return "", errors.New("wat backend: unsupported expression type")
}
// compileStatementWAT compiles a single statement of block into WAT.
//
//   - Expression statements evaluate the expression and `drop` its value.
//   - Nested block statements are compiled against their own Block.
//   - Return statements evaluate the optional value, then `return`.
//   - Local variable declarations store their initializer (if any) with
//     `local.set $<Index>`; declarations without one emit nothing.
//
// Statement kinds not listed above currently compile to nothing.
func compileStatementWAT(stmt Statement, block Block) (string, error) {
	switch stmt.Type {
	case Statement_Expression:
		exprStmt := stmt.Value.(ExpressionStatement)
		wat, err := compileExpressionWAT(exprStmt.Expression, block)
		if err != nil {
			return "", err
		}
		// The statement's value is unused, so discard it.
		return wat + "drop\n", nil
	case Statement_Block:
		// Named blockStmt so it does not shadow the block parameter.
		blockStmt := stmt.Value.(BlockStatement)
		return compileBlockWAT(blockStmt.Block)
	case Statement_Return:
		ret := stmt.Value.(ReturnStatement)
		if ret.Value == nil {
			// A bare `return` has no value to evaluate; dereferencing the nil
			// pointer would panic.
			return "return\n", nil
		}
		wat, err := compileExpressionWAT(*ret.Value, block)
		if err != nil {
			return "", err
		}
		return wat + "return\n", nil
	case Statement_DeclareLocalVariable:
		dlv := stmt.Value.(DeclareLocalVariableStatement)
		if dlv.Initializer == nil {
			return "", nil
		}
		wat, err := compileExpressionWAT(*dlv.Initializer, block)
		if err != nil {
			return "", err
		}
		return wat + "local.set $" + strconv.Itoa(block.Locals[dlv.Variable].Index) + "\n", nil
	}
	return "", nil
}
// compileBlockWAT compiles every statement of block in iteration order and
// concatenates the resulting WAT. Statements resolve locals against block
// itself. The first statement error aborts compilation and is returned as-is.
func compileBlockWAT(block Block) (string, error) {
	var out string
	for i := range block.Statements {
		stmtWAT, err := compileStatementWAT(block.Statements[i], block)
		if err != nil {
			return "", err
		}
		out += stmtWAT
	}
	return out, nil
}
// compileFunctionWAT emits a WAT `func` named $<Name> for function, followed
// by an export of the same name.
//
// Each entry of function.Locals is declared as `$<Index>` with its WAT type,
// matching the `local.get`/`local.set` names the body uses. The WAT text
// format requires every (param ...) to precede every (local ...), so
// parameters are declared in a first pass and plain locals in a second. Each
// pass keeps the iteration order of function.Locals.
//
// TODO: no (result ...) clause is emitted yet, so any value the body returns
// is not part of the function's signature.
func compileFunctionWAT(function ParsedFunction) (string, error) {
	funcWAT := "(func $" + function.Name + "\n"
	for _, params := range []bool{true, false} {
		kind := "local"
		if params {
			kind = "param"
		}
		for _, local := range function.Locals {
			if local.IsParameter != params {
				continue
			}
			funcWAT += "\t(" + kind + " $" + strconv.Itoa(local.Index) + " " + getWATType(local.Type) + ")\n"
		}
	}
	body, err := compileBlockWAT(function.Body)
	if err != nil {
		return "", err
	}
	funcWAT += body
	return funcWAT + ") (export \"" + function.Name + "\" (func $" + function.Name + "))", nil
}
// backendWAT compiles every function of file into one WAT module that also
// declares a linear memory with an initial size of 1 page. Functions appear
// in the order file.Functions yields them. The first compilation error aborts
// and is returned unchanged.
func backendWAT(file ParsedFile) (string, error) {
	var funcs string
	for i := range file.Functions {
		fnWAT, err := compileFunctionWAT(file.Functions[i])
		if err != nil {
			return "", err
		}
		funcs += fnWAT
	}
	return "(module (memory 1)\n" + funcs + ")", nil
}