575 lines
13 KiB
Go
575 lines
13 KiB
Go
package main
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"strconv"
|
|
"strings"
|
|
"unicode"
|
|
)
|
|
|
|
// Compiler translates parsed source files into a single WebAssembly text
// (WAT) module; see compile.
type Compiler struct {
	// Files holds the parsed files whose functions are emitted into the module.
	Files []*ParsedFile
	// Wasm64 selects i64 instead of i32 as the WAT type used for addresses,
	// i.e. for named and array types (see getAddressWATType).
	Wasm64 bool

	// CurrentBlock is the block whose statements are currently being
	// compiled; variable references are resolved against it via getLocal.
	CurrentBlock *Block
	// CurrentFunction is the function currently being compiled; temporary
	// locals needed by the generated code are appended to its Locals.
	CurrentFunction *ParsedFunction
}
|
|
|
|
func getPrimitiveWATType(primitive PrimitiveType) string {
|
|
switch primitive {
|
|
case Primitive_I8, Primitive_I16, Primitive_I32, Primitive_U8, Primitive_U16, Primitive_U32:
|
|
return "i32"
|
|
case Primitive_I64, Primitive_U64:
|
|
return "i64"
|
|
case Primitive_F32:
|
|
return "f32"
|
|
case Primitive_F64:
|
|
return "f64"
|
|
case Primitive_Bool:
|
|
return "i32"
|
|
}
|
|
|
|
panic(fmt.Sprintf("unhandled type in getPrimitiveWATType(): %s", primitive))
|
|
}
|
|
|
|
// safeASCIIIdentifier mangles an arbitrary source identifier into a string
// made only of ASCII letters, ASCII digits, '$' and '_', suitable for use after
// the '$' sigil of a WAT identifier.
//
// ASCII letters and digits are copied through unchanged. Every other rune r is
// written as "$" + decimal(r) + "_". The trailing "_" terminates the escape.
// Without it, an escape followed by a literal digit is ambiguous: "_1" and
// "η" (U+03B7 = 951) would both mangle to "$951", so two distinct source
// functions could end up with the same WAT name. Because literal characters
// are never '$' or '_', the mangled form can be decoded uniquely.
func safeASCIIIdentifier(identifier string) string {
	var b strings.Builder
	b.Grow(len(identifier))

	for _, r := range identifier {
		if r <= unicode.MaxASCII && (unicode.IsLetter(r) || unicode.IsDigit(r)) {
			b.WriteRune(r)
			continue
		}

		b.WriteByte('$')
		b.WriteString(strconv.Itoa(int(r)))
		b.WriteByte('_')
	}

	return b.String()
}
|
|
|
|
func getTypeCast(primitive PrimitiveType) string {
|
|
switch primitive {
|
|
case Primitive_I8:
|
|
return "i32.extend8_s\n"
|
|
case Primitive_U8:
|
|
return "i32.const 255\ni32.and\n"
|
|
case Primitive_I16:
|
|
return "i32.extend16_s\n"
|
|
case Primitive_U16:
|
|
return "i32.const 65535\ni32.and\n"
|
|
case Primitive_Bool:
|
|
return "i32.const 1\ni32.and\n"
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
func pushConstantNumberWAT(primitive PrimitiveType, value any) string {
|
|
switch primitive {
|
|
case Primitive_I8, Primitive_I16, Primitive_I32:
|
|
return "i32.const " + strconv.FormatInt(value.(int64), 10) + "\n"
|
|
case Primitive_U8, Primitive_U16, Primitive_U32:
|
|
return "i32.const " + strconv.FormatUint(value.(uint64), 10) + "\n"
|
|
case Primitive_I64:
|
|
return "i64.const " + strconv.FormatInt(value.(int64), 10) + "\n"
|
|
case Primitive_U64:
|
|
return "i64.const " + strconv.FormatUint(value.(uint64), 10) + "\n"
|
|
case Primitive_F32:
|
|
return "f32.const " + strconv.FormatFloat(value.(float64), 'f', -1, 32) + "\n"
|
|
case Primitive_F64:
|
|
return "f64.const " + strconv.FormatFloat(value.(float64), 'f', -1, 64) + "\n"
|
|
}
|
|
|
|
panic(fmt.Sprintf("invalid type passed to pushConstantNumberWAT(): %s", primitive))
|
|
}
|
|
|
|
func (c *Compiler) getAddressWATType() string {
|
|
if c.Wasm64 {
|
|
return "i64"
|
|
} else {
|
|
return "i32"
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) getWATType(t Type) string {
|
|
switch t.Type {
|
|
case Type_Primitive:
|
|
return getPrimitiveWATType(t.Value.(PrimitiveType))
|
|
case Type_Named, Type_Array:
|
|
return c.getAddressWATType()
|
|
case Type_Tuple:
|
|
panic(fmt.Sprintf("tuple type passed to getWATType(): %s", t))
|
|
}
|
|
|
|
panic(fmt.Sprintf("type not implemented in getWATType(): %s", t))
|
|
}
|
|
|
|
func castPrimitiveWAT(from PrimitiveType, to PrimitiveType) (string, error) {
|
|
if from == to {
|
|
return "", nil
|
|
}
|
|
|
|
if from == Primitive_Bool || to == Primitive_Bool {
|
|
return "", errors.New("cannot upcast from or to bool")
|
|
}
|
|
|
|
fromFloat := isFloatingPoint(from)
|
|
toFloat := isFloatingPoint(to)
|
|
if fromFloat && toFloat {
|
|
if to == Primitive_F32 {
|
|
return "f32.demote_f64\n", nil
|
|
} else {
|
|
return "f64.promote_f32\n", nil
|
|
}
|
|
}
|
|
|
|
if toFloat {
|
|
suffix := ""
|
|
if isUnsignedInt(to) {
|
|
suffix = "u"
|
|
} else {
|
|
suffix = "s"
|
|
}
|
|
|
|
return getPrimitiveWATType(to) + ".convert_" + getPrimitiveWATType(from) + "_" + suffix + "\n", nil
|
|
}
|
|
|
|
if fromFloat {
|
|
suffix := ""
|
|
|
|
if isUnsignedInt(to) {
|
|
suffix = "u"
|
|
} else {
|
|
suffix = "s"
|
|
}
|
|
|
|
return getPrimitiveWATType(to) + ".trunc_" + getPrimitiveWATType(from) + "_" + suffix + "\n", nil
|
|
}
|
|
|
|
if getBits(from) == getBits(to) {
|
|
return "", nil
|
|
}
|
|
|
|
if getPrimitiveWATType(from) == getPrimitiveWATType(to) {
|
|
if getBits(to) < getBits(from) {
|
|
return getTypeCast(to), nil
|
|
}
|
|
|
|
return "", nil
|
|
}
|
|
|
|
if getBits(from) < 64 && getBits(to) == 64 {
|
|
suffix := ""
|
|
if isUnsignedInt(from) {
|
|
suffix = "u"
|
|
} else {
|
|
suffix = "s"
|
|
}
|
|
|
|
return "i64.extend_i32_" + suffix + "\n", nil
|
|
}
|
|
|
|
return "i32.wrap_i64\n" + getTypeCast(to), nil
|
|
}
|
|
|
|
func (c *Compiler) compileAssignmentExpressionWAT(assignment AssignmentExpression) (string, error) {
|
|
lhs := assignment.Lhs
|
|
|
|
exprWAT, err := c.compileExpressionWAT(assignment.Value)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
switch lhs.Type {
|
|
case Expression_VariableReference:
|
|
ref := lhs.Value.(VariableReferenceExpression)
|
|
local := strconv.Itoa(getLocal(c.CurrentBlock, ref.Variable).Index)
|
|
return exprWAT + "local.tee $" + local + "\n", nil
|
|
case Expression_ArrayAccess:
|
|
panic("TODO") // TODO
|
|
case Expression_RawMemoryReference:
|
|
raw := lhs.Value.(RawMemoryReferenceExpression)
|
|
|
|
local := Local{Name: "", Type: *lhs.ValueType, IsParameter: false, Index: len(c.CurrentFunction.Locals)}
|
|
c.CurrentFunction.Locals = append(c.CurrentFunction.Locals, local)
|
|
|
|
if raw.Type.Type != Type_Primitive {
|
|
panic("TODO") //TODO
|
|
}
|
|
|
|
addrWAT, err := c.compileExpressionWAT(raw.Address)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// TODO: should leave a copy of the stored value on the stack
|
|
return addrWAT + exprWAT +
|
|
"local.tee " + strconv.Itoa(local.Index) + "\n" +
|
|
c.getWATType(raw.Type) + ".store\n" +
|
|
"local.get " + strconv.Itoa(local.Index) + "\n", nil
|
|
}
|
|
|
|
panic("assignment expr not implemented")
|
|
}
|
|
|
|
func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) {
|
|
var err error
|
|
|
|
switch expr.Type {
|
|
case Expression_Assignment:
|
|
ass := expr.Value.(AssignmentExpression)
|
|
return c.compileAssignmentExpressionWAT(ass)
|
|
case Expression_Literal:
|
|
lit := expr.Value.(LiteralExpression)
|
|
switch lit.Literal.Type {
|
|
case Literal_Number:
|
|
return pushConstantNumberWAT(lit.Literal.Primitive, lit.Literal.Value), nil
|
|
case Literal_Boolean:
|
|
if lit.Literal.Value.(bool) {
|
|
return "i32.const 1\n", nil
|
|
} else {
|
|
return "i32.const 0\n", nil
|
|
}
|
|
case Literal_String:
|
|
panic("not implemented")
|
|
}
|
|
case Expression_VariableReference:
|
|
ref := expr.Value.(VariableReferenceExpression)
|
|
|
|
cast := ""
|
|
if expr.ValueType.Type == Type_Primitive {
|
|
// TODO: technically only needed for function parameters because functions can be called from outside WASM so they might not be fully type checked
|
|
cast = getTypeCast(expr.ValueType.Value.(PrimitiveType))
|
|
}
|
|
|
|
return "local.get $" + strconv.Itoa(getLocal(c.CurrentBlock, ref.Variable).Index) + "\n" + cast, nil
|
|
case Expression_Binary:
|
|
binary := expr.Value.(BinaryExpression)
|
|
|
|
// TODO: currently fine, only allowed for primitive types, but should be expanded to allow e.g. strings
|
|
operandType := binary.Left.ValueType.Value.(PrimitiveType)
|
|
exprType := expr.ValueType.Value.(PrimitiveType)
|
|
|
|
watLeft, err := c.compileExpressionWAT(binary.Left)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
watRight, err := c.compileExpressionWAT(binary.Right)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
op := ""
|
|
|
|
suffix := ""
|
|
if isUnsignedInt(operandType) {
|
|
suffix = "u"
|
|
} else {
|
|
suffix = "s"
|
|
}
|
|
|
|
switch binary.Operation {
|
|
case Operation_Add:
|
|
op = getPrimitiveWATType(operandType) + ".add\n"
|
|
case Operation_Sub:
|
|
op = getPrimitiveWATType(operandType) + ".sub\n"
|
|
case Operation_Mul:
|
|
op = getPrimitiveWATType(operandType) + ".mul\n"
|
|
case Operation_Div:
|
|
op = getPrimitiveWATType(operandType) + ".div_" + suffix + "\n"
|
|
case Operation_Mod:
|
|
op = getPrimitiveWATType(operandType) + ".rem_" + suffix + "\n"
|
|
case Operation_Greater:
|
|
op = getPrimitiveWATType(operandType) + ".gt_" + suffix + "\n"
|
|
case Operation_Less:
|
|
op = getPrimitiveWATType(operandType) + ".lt_" + suffix + "\n"
|
|
case Operation_GreaterEquals:
|
|
op = getPrimitiveWATType(operandType) + ".ge_" + suffix + "\n"
|
|
case Operation_LessEquals:
|
|
op = getPrimitiveWATType(operandType) + ".le_" + suffix + "\n"
|
|
case Operation_NotEquals:
|
|
op = getPrimitiveWATType(operandType) + ".ne\n"
|
|
case Operation_Equals:
|
|
op = getPrimitiveWATType(operandType) + ".eq\n"
|
|
default:
|
|
panic("operation not implemented")
|
|
}
|
|
|
|
return watLeft + watRight + op + getTypeCast(exprType), nil
|
|
case Expression_Tuple:
|
|
tuple := expr.Value.(TupleExpression)
|
|
|
|
wat := ""
|
|
for _, member := range tuple.Members {
|
|
memberWAT, err := c.compileExpressionWAT(member)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
wat += memberWAT
|
|
}
|
|
|
|
return wat, nil
|
|
case Expression_FunctionCall:
|
|
fc := expr.Value.(FunctionCallExpression)
|
|
|
|
wat := ""
|
|
if fc.Parameters != nil {
|
|
wat, err = c.compileExpressionWAT(*fc.Parameters)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
}
|
|
|
|
return wat + "call $" + fc.Function + "\n", nil
|
|
case Expression_Negate:
|
|
neg := expr.Value.(NegateExpression)
|
|
exprType := expr.ValueType.Value.(PrimitiveType)
|
|
|
|
wat, err := c.compileExpressionWAT(neg.Value)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
watType := getPrimitiveWATType(exprType)
|
|
if isSignedInt(exprType) || isUnsignedInt(exprType) {
|
|
return watType + ".const 0\n" + wat + watType + ".sub\n", nil
|
|
}
|
|
|
|
if isFloatingPoint(exprType) {
|
|
return watType + ".neg\n", nil
|
|
}
|
|
case Expression_Cast:
|
|
cast := expr.Value.(CastExpression)
|
|
|
|
wat, err := c.compileExpressionWAT(cast.Value)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// TODO: fine, as it is currently only allowed for primitive types
|
|
fromType := cast.Value.ValueType.Value.(PrimitiveType)
|
|
toType := cast.Type.Value.(PrimitiveType)
|
|
castWAT, err := castPrimitiveWAT(fromType, toType)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return wat + castWAT, nil
|
|
case Expression_RawMemoryReference:
|
|
raw := expr.Value.(RawMemoryReferenceExpression)
|
|
|
|
wat, err := c.compileExpressionWAT(raw.Address)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if raw.Type.Type == Type_Primitive {
|
|
wat += c.getWATType(raw.Type) + ".load\n"
|
|
}
|
|
|
|
return wat, nil
|
|
}
|
|
|
|
panic("expr not implemented")
|
|
}
|
|
|
|
func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, error) {
|
|
switch stmt.Type {
|
|
case Statement_Expression:
|
|
expr := stmt.Value.(ExpressionStatement)
|
|
wat, err := c.compileExpressionWAT(expr.Expression)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
numItems := 0
|
|
if expr.Expression.ValueType != nil {
|
|
numItems = 1
|
|
|
|
if expr.Expression.ValueType.Type == Type_Tuple {
|
|
numItems = len(expr.Expression.ValueType.Value.(TupleType).Types)
|
|
}
|
|
}
|
|
|
|
return wat + strings.Repeat("drop\n", numItems), nil
|
|
case Statement_Block:
|
|
block := stmt.Value.(BlockStatement)
|
|
wat, err := c.compileBlockWAT(block.Block)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return wat, nil
|
|
case Statement_Return:
|
|
ret := stmt.Value.(ReturnStatement)
|
|
wat, err := c.compileExpressionWAT(*ret.Value)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return wat + "return\n", nil
|
|
case Statement_DeclareLocalVariable:
|
|
dlv := stmt.Value.(DeclareLocalVariableStatement)
|
|
if dlv.Initializer == nil {
|
|
return "", nil
|
|
}
|
|
|
|
wat, err := c.compileExpressionWAT(*dlv.Initializer)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return wat + "local.set $" + strconv.Itoa(block.Locals[dlv.Variable].Index) + "\n", nil
|
|
case Statement_If:
|
|
ifS := stmt.Value.(IfStatement)
|
|
|
|
conditionWAT, err := c.compileExpressionWAT(ifS.Condition)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
condBlockWAT, err := c.compileBlockWAT(ifS.ConditionalBlock)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
wat := ""
|
|
|
|
if ifS.ElseBlock != nil {
|
|
wat += "block\n"
|
|
}
|
|
|
|
// condition
|
|
wat += "block\n"
|
|
|
|
wat += conditionWAT
|
|
wat += "i32.eqz\n" // logical not
|
|
wat += "br_if 0\n"
|
|
|
|
// condition is true
|
|
wat += condBlockWAT
|
|
|
|
if ifS.ElseBlock != nil {
|
|
wat += "br 1\n" // jump over else block
|
|
}
|
|
|
|
wat += "end\n"
|
|
|
|
if ifS.ElseBlock != nil {
|
|
// condition is false
|
|
elseWAT, err := c.compileBlockWAT(ifS.ElseBlock)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
wat += elseWAT
|
|
wat += "end\n"
|
|
}
|
|
|
|
return wat, nil
|
|
case Statement_WhileLoop:
|
|
while := stmt.Value.(WhileLoopStatement)
|
|
|
|
conditionWAT, err := c.compileExpressionWAT(while.Condition)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
bodyWAT, err := c.compileBlockWAT(while.Body)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
wat := "block\n"
|
|
wat += "loop\n"
|
|
|
|
wat += conditionWAT
|
|
wat += "i32.eqz\n"
|
|
wat += "br_if 1\n"
|
|
|
|
wat += bodyWAT
|
|
wat += "br 0\n"
|
|
|
|
wat += "end\n"
|
|
wat += "end\n"
|
|
|
|
return wat, nil
|
|
}
|
|
|
|
panic("stmt not implemented")
|
|
}
|
|
|
|
func (c *Compiler) compileBlockWAT(block *Block) (string, error) {
|
|
blockWAT := ""
|
|
|
|
for _, stmt := range block.Statements {
|
|
c.CurrentBlock = block
|
|
wat, err := c.compileStatementWAT(stmt, block)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
blockWAT += wat
|
|
}
|
|
|
|
return blockWAT, nil
|
|
}
|
|
|
|
func (c *Compiler) compileFunctionWAT(function *ParsedFunction) (string, error) {
|
|
c.CurrentFunction = function
|
|
blockWat, err := c.compileBlockWAT(function.Body)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
funcWAT := "(func $" + safeASCIIIdentifier(function.FullName) + " (export \"" + function.FullName + "\")\n"
|
|
|
|
for _, local := range function.Locals {
|
|
if !local.IsParameter {
|
|
continue
|
|
}
|
|
funcWAT += "\t(param $" + strconv.Itoa(local.Index) + " " + c.getWATType(local.Type) + ")\n"
|
|
}
|
|
|
|
returnTypes := []Type{}
|
|
if function.ReturnType != nil {
|
|
returnTypes = []Type{*function.ReturnType}
|
|
if function.ReturnType.Type == Type_Tuple {
|
|
returnTypes = function.ReturnType.Value.(TupleType).Types
|
|
}
|
|
}
|
|
|
|
for _, t := range returnTypes {
|
|
funcWAT += "\t(result " + c.getWATType(t) + ")\n"
|
|
}
|
|
|
|
for _, local := range function.Locals {
|
|
if local.IsParameter {
|
|
continue
|
|
}
|
|
funcWAT += "\t(local $" + strconv.Itoa(local.Index) + " " + c.getWATType(local.Type) + ")\n"
|
|
}
|
|
|
|
return funcWAT + blockWat + ")\n", nil
|
|
}
|
|
|
|
func (c *Compiler) compile() (string, error) {
|
|
module := "(module (memory 0)\n"
|
|
|
|
for _, file := range c.Files {
|
|
for i := range file.Functions {
|
|
wat, err := c.compileFunctionWAT(&file.Functions[i])
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
module += wat
|
|
}
|
|
}
|
|
|
|
module += ")"
|
|
return module, nil
|
|
}
|