package main import ( "errors" "strconv" "strings" "unicode" ) type Compiler struct { Files []*ParsedFile Wasm64 bool } func getPrimitiveWATType(primitive PrimitiveType) string { switch primitive { case Primitive_I8, Primitive_I16, Primitive_I32, Primitive_U8, Primitive_U16, Primitive_U32: return "i32" case Primitive_I64, Primitive_U64: return "i64" case Primitive_F32: return "f32" case Primitive_F64: return "f64" case Primitive_Bool: return "i32" } panic("unhandled type") } func safeASCIIIdentifier(identifier string) string { ascii := "" for _, rune := range identifier { if rune < unicode.MaxASCII && (unicode.IsLetter(rune) || unicode.IsDigit(rune)) { ascii += string(rune) continue } ascii += "$" + strconv.Itoa(int(rune)) } return ascii } func getTypeCast(primitive PrimitiveType) string { switch primitive { case Primitive_I8: return "i32.extend8_s\n" case Primitive_U8: return "i32.const 255\ni32.and\n" case Primitive_I16: return "i32.extend16_s\n" case Primitive_U16: return "i32.const 65535\ni32.and\n" case Primitive_Bool: return "i32.const 1\ni32.and\n" } return "" } func pushConstantNumberWAT(primitive PrimitiveType, value any) string { switch primitive { case Primitive_I8, Primitive_I16, Primitive_I32: return "i32.const " + strconv.FormatInt(value.(int64), 10) + "\n" case Primitive_U8, Primitive_U16, Primitive_U32: return "i32.const " + strconv.FormatUint(value.(uint64), 10) + "\n" case Primitive_I64: return "i64.const " + strconv.FormatInt(value.(int64), 10) + "\n" case Primitive_U64: return "i64.const " + strconv.FormatUint(value.(uint64), 10) + "\n" case Primitive_F32: return "f32.const " + strconv.FormatFloat(value.(float64), 'f', -1, 32) + "\n" case Primitive_F64: return "f64.const " + strconv.FormatFloat(value.(float64), 'f', -1, 64) + "\n" } panic("invalid type") } func (c *Compiler) getWATType(t Type) string { switch t.Type { case Type_Primitive: return getPrimitiveWATType(t.Value.(PrimitiveType)) case Type_Named, Type_Array: if c.Wasm64 { return "i64" } else { return "i32" } case Type_Tuple: panic("tuple type passed to getWATType()") } panic("type not implemented in getWATType()") } func castPrimitiveWAT(from PrimitiveType, to PrimitiveType) (string, error) { if from == to { return "", nil } if from == Primitive_Bool || to == Primitive_Bool { return "", errors.New("cannot upcast from or to bool") } fromFloat := isFloatingPoint(from) toFloat := isFloatingPoint(to) if fromFloat && toFloat { if to == Primitive_F32 { return "f32.demote_f64\n", nil } else { return "f64.promote_f32\n", nil } } if toFloat { suffix := "" if isUnsignedInt(to) { suffix = "u" } else { suffix = "s" } return getPrimitiveWATType(to) + ".convert_" + getPrimitiveWATType(from) + "_" + suffix + "\n", nil } if fromFloat { suffix := "" if isUnsignedInt(to) { suffix = "u" } else { suffix = "s" } return getPrimitiveWATType(to) + ".trunc_" + getPrimitiveWATType(from) + "_" + suffix + "\n", nil } if getBits(from) == getBits(to) { return "", nil } if getPrimitiveWATType(from) == getPrimitiveWATType(to) { if getBits(to) < getBits(from) { return getTypeCast(to), nil } return "", nil } if getBits(from) < 64 && getBits(to) == 64 { suffix := "" if isUnsignedInt(from) { suffix = "u" } else { suffix = "s" } return "i64.extend_i32_" + suffix + "\n", nil } return "i32.wrap_i64\n" + getTypeCast(to), nil } func (c *Compiler) compileExpressionWAT(expr Expression, block *Block) (string, error) { var err error switch expr.Type { case Expression_Assignment: ass := expr.Value.(AssignmentExpression) exprWAT, err := c.compileExpressionWAT(ass.Value, block) if err != nil { return "", err } local := strconv.Itoa(block.Locals[ass.Variable].Index) getLocal := "local.get $" + local + "\n" setLocal := "local.set $" + local + "\n" return exprWAT + setLocal + getLocal, nil case Expression_Literal: lit := expr.Value.(LiteralExpression) switch lit.Literal.Type { case Literal_Number: return pushConstantNumberWAT(lit.Literal.Primitive, lit.Literal.Value), nil case Literal_Boolean: if lit.Literal.Value.(bool) { return "i32.const 1\n", nil } else { return "i32.const 0\n", nil } case Literal_String: panic("not implemented") } case Expression_VariableReference: ref := expr.Value.(VariableReferenceExpression) cast := "" if expr.ValueType.Type == Type_Primitive { // TODO: technically only needed for function parameters because functions can be called from outside WASM so they might not be fully type checked cast = getTypeCast(expr.ValueType.Value.(PrimitiveType)) } return "local.get $" + strconv.Itoa(block.Locals[ref.Variable].Index) + "\n" + cast, nil case Expression_Binary: binary := expr.Value.(BinaryExpression) // TODO: currently fine, only allowed for primitive types, but should be expanded to allow e.g. strings operandType := binary.Left.ValueType.Value.(PrimitiveType) exprType := expr.ValueType.Value.(PrimitiveType) watLeft, err := c.compileExpressionWAT(binary.Left, block) if err != nil { return "", err } watRight, err := c.compileExpressionWAT(binary.Right, block) if err != nil { return "", err } op := "" suffix := "" if isUnsignedInt(operandType) { suffix = "u" } else { suffix = "s" } switch binary.Operation { case Operation_Add: op = getPrimitiveWATType(operandType) + ".add\n" case Operation_Sub: op = getPrimitiveWATType(operandType) + ".sub\n" case Operation_Mul: op = getPrimitiveWATType(operandType) + ".mul\n" case Operation_Div: op = getPrimitiveWATType(operandType) + ".div_" + suffix + "\n" case Operation_Mod: op = getPrimitiveWATType(operandType) + ".rem_" + suffix + "\n" case Operation_Greater: op = getPrimitiveWATType(operandType) + ".gt_" + suffix + "\n" case Operation_Less: op = getPrimitiveWATType(operandType) + ".lt_" + suffix + "\n" case Operation_GreaterEquals: op = getPrimitiveWATType(operandType) + ".ge_" + suffix + "\n" case Operation_LessEquals: op = getPrimitiveWATType(operandType) + ".le_" + suffix + "\n" case Operation_NotEquals: op = getPrimitiveWATType(operandType) + ".ne\n" case Operation_Equals: op = getPrimitiveWATType(operandType) + ".eq\n" default: panic("operation not implemented") } return watLeft + watRight + op + getTypeCast(exprType), nil case Expression_Tuple: tuple := expr.Value.(TupleExpression) wat := "" for _, member := range tuple.Members { memberWAT, err := c.compileExpressionWAT(member, block) if err != nil { return "", err } wat += memberWAT } return wat, nil case Expression_FunctionCall: fc := expr.Value.(FunctionCallExpression) wat := "" if fc.Parameters != nil { wat, err = c.compileExpressionWAT(*fc.Parameters, block) if err != nil { return "", err } } return wat + "call $" + fc.Function + "\n", nil case Expression_Negate: neg := expr.Value.(NegateExpression) exprType := expr.ValueType.Value.(PrimitiveType) wat, err := c.compileExpressionWAT(neg.Value, block) if err != nil { return "", err } watType := getPrimitiveWATType(exprType) if isSignedInt(exprType) || isUnsignedInt(exprType) { return watType + ".const 0\n" + wat + watType + ".sub\n", nil } if isFloatingPoint(exprType) { return watType + ".neg\n", nil } case Expression_Cast: cast := expr.Value.(CastExpression) wat, err := c.compileExpressionWAT(cast.Value, block) if err != nil { return "", err } // TODO: fine, as it is currently only allowed for primitive types fromType := cast.Value.ValueType.Value.(PrimitiveType) toType := cast.Type.Value.(PrimitiveType) castWAT, err := castPrimitiveWAT(fromType, toType) if err != nil { return "", err } return wat + castWAT, nil } panic("expr not implemented") } func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, error) { switch stmt.Type { case Statement_Expression: expr := stmt.Value.(ExpressionStatement) wat, err := c.compileExpressionWAT(expr.Expression, block) if err != nil { return "", err } numItems := 0 if expr.Expression.ValueType != nil { numItems = 1 if expr.Expression.ValueType.Type == Type_Tuple { numItems = len(expr.Expression.ValueType.Value.(TupleType).Types) } } return wat + strings.Repeat("drop\n", numItems), nil case Statement_Block: block := stmt.Value.(BlockStatement) wat, err := c.compileBlockWAT(block.Block) if err != nil { return "", err } return wat, nil case Statement_Return: ret := stmt.Value.(ReturnStatement) wat, err := c.compileExpressionWAT(*ret.Value, block) if err != nil { return "", err } return wat + "return\n", nil case Statement_DeclareLocalVariable: dlv := stmt.Value.(DeclareLocalVariableStatement) if dlv.Initializer == nil { return "", nil } wat, err := c.compileExpressionWAT(*dlv.Initializer, block) if err != nil { return "", err } return wat + "local.set $" + strconv.Itoa(block.Locals[dlv.Variable].Index) + "\n", nil case Statement_If: ifS := stmt.Value.(IfStatement) conditionWAT, err := c.compileExpressionWAT(ifS.Condition, block) if err != nil { return "", err } condBlockWAT, err := c.compileBlockWAT(ifS.ConditionalBlock) if err != nil { return "", err } wat := "" if ifS.ElseBlock != nil { wat += "block\n" } // condition wat += "block\n" wat += conditionWAT wat += "i32.eqz\n" // logical not wat += "br_if 0\n" // condition is true wat += condBlockWAT if ifS.ElseBlock != nil { wat += "br 1\n" // jump over else block } wat += "end\n" if ifS.ElseBlock != nil { // condition is false elseWAT, err := c.compileBlockWAT(ifS.ElseBlock) if err != nil { return "", err } wat += elseWAT wat += "end\n" } return wat, nil } panic("stmt not implemented") } func (c *Compiler) compileBlockWAT(block *Block) (string, error) { blockWAT := "" for _, stmt := range block.Statements { wat, err := c.compileStatementWAT(stmt, block) if err != nil { return "", err } blockWAT += wat } return blockWAT, nil } func (c *Compiler) compileFunctionWAT(function ParsedFunction) (string, error) { funcWAT := "(func $" + safeASCIIIdentifier(function.FullName) + " (export \"" + function.FullName + "\")\n" for _, local := range function.Locals { if !local.IsParameter { continue } funcWAT += "\t(param $" + strconv.Itoa(local.Index) + " " + c.getWATType(local.Type) + ")\n" } returnTypes := []Type{} if function.ReturnType != nil { returnTypes = []Type{*function.ReturnType} if function.ReturnType.Type == Type_Tuple { returnTypes = function.ReturnType.Value.(TupleType).Types } } for _, t := range returnTypes { funcWAT += "\t(result " + c.getWATType(t) + ")\n" } for _, local := range function.Locals { if local.IsParameter { continue } funcWAT += "\t(local $" + strconv.Itoa(local.Index) + " " + c.getWATType(local.Type) + ")\n" } wat, err := c.compileBlockWAT(function.Body) if err != nil { return "", err } funcWAT += wat return funcWAT + ")\n", nil } func (c *Compiler) compile() (string, error) { module := "(module\n" for _, file := range c.Files { for _, function := range file.Functions { wat, err := c.compileFunctionWAT(function) if err != nil { return "", err } module += wat } } module += ")" return module, nil }