diff --git a/backend_wat.go b/backend_wat.go index 9431d9a..1884999 100644 --- a/backend_wat.go +++ b/backend_wat.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "strconv" - "strings" "unicode" ) @@ -50,37 +49,37 @@ func safeASCIIIdentifier(identifier string) string { return ascii } -func getTypeCast(primitive PrimitiveType) string { +func getTypeCast(primitive PrimitiveType) Code { switch primitive { case Primitive_I8: - return "i32.extend8_s\n" + return ofInstruction("i32.extend8_s") case Primitive_U8: - return "i32.const 255\ni32.and\n" + return ofInstruction("i32.const 255", "i32.and") case Primitive_I16: - return "i32.extend16_s\n" + return ofInstruction("i32.extend16_s") case Primitive_U16: - return "i32.const 65535\ni32.and\n" + return ofInstruction("i32.const 65535", "i32.and") case Primitive_Bool: - return "i32.const 0\ni32.ne\n" + return ofInstruction("i32.const 0", "i32.ne") } - return "" + return emptyCode() } -func pushConstantNumberWAT(primitive PrimitiveType, value any) string { +func pushConstantNumberWAT(primitive PrimitiveType, value any) Code { switch primitive { case Primitive_I8, Primitive_I16, Primitive_I32: - return "i32.const " + strconv.FormatInt(value.(int64), 10) + "\n" + return ofInstruction("i32.const " + strconv.FormatInt(value.(int64), 10)) case Primitive_U8, Primitive_U16, Primitive_U32: - return "i32.const " + strconv.FormatUint(value.(uint64), 10) + "\n" + return ofInstruction("i32.const " + strconv.FormatUint(value.(uint64), 10)) case Primitive_I64: - return "i64.const " + strconv.FormatInt(value.(int64), 10) + "\n" + return ofInstruction("i64.const " + strconv.FormatInt(value.(int64), 10)) case Primitive_U64: - return "i64.const " + strconv.FormatUint(value.(uint64), 10) + "\n" + return ofInstruction("i64.const " + strconv.FormatUint(value.(uint64), 10)) case Primitive_F32: - return "f32.const " + strconv.FormatFloat(value.(float64), 'f', -1, 32) + "\n" + return ofInstruction("f32.const " + strconv.FormatFloat(value.(float64), 'f', -1, 32)) case Primitive_F64: - return "f64.const " + strconv.FormatFloat(value.(float64), 'f', -1, 64) + "\n" + return ofInstruction("f64.const " + strconv.FormatFloat(value.(float64), 'f', -1, 64)) } panic(fmt.Sprintf("invalid type passed to pushConstantNumberWAT(): %s", primitive)) @@ -143,38 +142,39 @@ func (c *Compiler) getTypeSizeBytes(t Type) int { panic(fmt.Sprintf("unhandled type in getTypeSizeBytes(): %s", t)) } -func castPrimitiveWAT(from PrimitiveType, to PrimitiveType) (string, error) { +func castPrimitiveWAT(from PrimitiveType, to PrimitiveType) (Code, error) { if from == to { - return "", nil + return emptyCode(), nil } if from == Primitive_Bool || to == Primitive_Bool { - return "", errors.New("cannot upcast from or to bool") + return emptyCode(), errors.New("cannot upcast from or to bool") } fromFloat := isFloatingPoint(from) toFloat := isFloatingPoint(to) if fromFloat && toFloat { if to == Primitive_F32 { - return "f32.demote_f64\n", nil + return ofInstruction("f32.demote_f64"), nil } else { - return "f64.promote_f32\n", nil + return ofInstruction("f64.promote_f32"), nil } } if toFloat { - suffix := "" + var suffix string + if isUnsignedInt(to) { suffix = "u" } else { suffix = "s" } - return getPrimitiveWATType(to) + ".convert_" + getPrimitiveWATType(from) + "_" + suffix + "\n", nil + return ofInstruction(getPrimitiveWATType(to) + ".convert_" + getPrimitiveWATType(from) + "_" + suffix), nil } if fromFloat { - suffix := "" + var suffix string if isUnsignedInt(to) { suffix = "u" @@ -182,11 +182,11 @@ func castPrimitiveWAT(from PrimitiveType, to PrimitiveType) (string, error) { suffix = "s" } - return getPrimitiveWATType(to) + ".trunc_" + getPrimitiveWATType(from) + "_" + suffix + "\n", nil + return ofInstruction(getPrimitiveWATType(to) + ".trunc_" + getPrimitiveWATType(from) + "_" + suffix), nil } if getBits(from) == getBits(to) { - return "", nil + return emptyCode(), nil } if getPrimitiveWATType(from) == getPrimitiveWATType(to) { @@ -194,36 +194,41 @@ func castPrimitiveWAT(from PrimitiveType, to PrimitiveType) (string, error) { return getTypeCast(to), nil } - return "", nil + return emptyCode(), nil } if getBits(from) < 64 && getBits(to) == 64 { - suffix := "" + var suffix string + if isUnsignedInt(from) { suffix = "u" } else { suffix = "s" } - return "i64.extend_i32_" + suffix + "\n", nil + return ofInstruction("i64.extend_i32_" + suffix), nil } - return "i32.wrap_i64\n" + getTypeCast(to), nil + code := ofInstruction("i32.wrap_i64") + code.addAll(getTypeCast(to)) + return code, nil } -func (c *Compiler) compileAssignmentExpressionWAT(assignment AssignmentExpression) (string, error) { +func (c *Compiler) compileAssignmentExpressionWAT(assignment AssignmentExpression) (Code, error) { lhs := assignment.Lhs exprWAT, err := c.compileExpressionWAT(assignment.Value) if err != nil { - return "", err + return emptyCode(), err } switch lhs.Type { case Expression_VariableReference: ref := lhs.Value.(VariableReferenceExpression) local := strconv.Itoa(getLocal(c.CurrentBlock, ref.Variable).Index) - return exprWAT + "local.tee $" + local + "\n", nil + + exprWAT.add("local.tee $" + local) + return exprWAT, nil case Expression_ArrayAccess: array := lhs.Value.(ArrayAccessExpression) @@ -238,63 +243,72 @@ func (c *Compiler) compileAssignmentExpressionWAT(assignment AssignmentExpressio arrayWAT, err := c.compileExpressionWAT(array.Array) if err != nil { - return "", err + return emptyCode(), err } - arrayWAT += "local.set $" + strconv.Itoa(localArray.Index) + "\n" + arrayWAT.add("local.set $" + strconv.Itoa(localArray.Index)) indexWAT, err := c.compileExpressionWAT(array.Index) if err != nil { - return "", err + return emptyCode(), err } if !c.Wasm64 { cast, err := castPrimitiveWAT(Primitive_I64, Primitive_I32) if err != nil { - return "", err + return emptyCode(), err } - indexWAT += cast + indexWAT.addAll(cast) } - indexWAT += "local.set $" + strconv.Itoa(localIndex.Index) + "\n" + indexWAT.add("local.set $" + strconv.Itoa(localIndex.Index)) - wat := arrayWAT + indexWAT + wat := concat(arrayWAT, indexWAT) if _, ok := c.CompileOptions[COMPILE_OPTION_NO_BOUNDS_CHECK]; !ok { // Error if index < 0 - wat += "block\n" - wat += "local.get $" + strconv.Itoa(localIndex.Index) + "\n" - wat += c.getAddressWATType() + ".const 0\n" - wat += c.getAddressWATType() + ".ge_s\n" - wat += "br_if 0\n" - wat += "call $__builtin_panic\n" - wat += "end\n" + wat.add( + "block", + "local.get $"+strconv.Itoa(localIndex.Index), + c.getAddressWATType()+".const 0", + c.getAddressWATType()+".ge_s", + "br_if 0", + "call $__builtin_panic", + "end", + ) // Error if index >= array length - wat += "block\n" - wat += "local.get $" + strconv.Itoa(localIndex.Index) + "\n" - wat += "local.get $" + strconv.Itoa(localArray.Index) + "\n" - wat += "i32.load\n" // Load array length - wat += c.getAddressWATType() + ".lt_s\n" - wat += "br_if 0\n" - wat += "call $__builtin_panic\n" - wat += "end\n" + wat.add( + "block", + "local.get $"+strconv.Itoa(localIndex.Index), + "local.get $"+strconv.Itoa(localArray.Index), + "i32.load", // Load array length + c.getAddressWATType()+".lt_s", + "br_if 0", + "call $__builtin_panic", + "end", + ) } elementType := array.Array.ValueType.Value.(ArrayType).ElementType - wat += "local.get $" + strconv.Itoa(localIndex.Index) + "\n" - wat += c.getAddressWATType() + ".const " + strconv.Itoa(c.getTypeSizeBytes(elementType)) + "\n" - wat += c.getAddressWATType() + ".mul\n" - wat += "local.get $" + strconv.Itoa(localArray.Index) + "\n" - wat += c.getAddressWATType() + ".add\n" - wat += c.getAddressWATType() + ".const 4\n" // first 4 bytes = length - wat += c.getAddressWATType() + ".add\n" + wat.add( + "local.get $"+strconv.Itoa(localIndex.Index), + c.getAddressWATType()+".const "+strconv.Itoa(c.getTypeSizeBytes(elementType)), + c.getAddressWATType()+".mul", + "local.get $"+strconv.Itoa(localArray.Index), + c.getAddressWATType()+".add", + c.getAddressWATType()+".const 4", // first 4 bytes = length + c.getAddressWATType()+".add", + ) - wat += exprWAT - wat += "local.tee $" + strconv.Itoa(localElement.Index) + "\n" - wat += c.getWATType(elementType) + ".store\n" // TODO: use load8/load16(_s/u) for smaller types - wat += "local.get $" + strconv.Itoa(localElement.Index) + "\n" + wat.addAll(exprWAT) + + wat.add( + "local.tee $"+strconv.Itoa(localElement.Index), + c.getWATType(elementType)+".store", // TODO: use load8/load16(_s/u) for smaller types + "local.get $"+strconv.Itoa(localElement.Index), + ) return wat, nil case Expression_RawMemoryReference: @@ -309,20 +323,25 @@ func (c *Compiler) compileAssignmentExpressionWAT(assignment AssignmentExpressio addrWAT, err := c.compileExpressionWAT(raw.Address) if err != nil { - return "", err + return emptyCode(), err } // TODO: should leave a copy of the stored value on the stack - return addrWAT + exprWAT + - "local.tee $" + strconv.Itoa(local.Index) + "\n" + - c.getWATType(raw.Type) + ".store\n" + - "local.get $" + strconv.Itoa(local.Index) + "\n", nil + code := addrWAT.clone() + code.addAll(exprWAT) + code.add( + "local.tee $"+strconv.Itoa(local.Index), + c.getWATType(raw.Type)+".store", + "local.get $"+strconv.Itoa(local.Index), + ) + + return code, nil } panic("assignment expr not implemented") } -func (c *Compiler) compileAssignmentUpdateExpressionWAT(lhs Expression, updateWAT string, evaluateToOldValue bool) (string, error) { +func (c *Compiler) compileAssignmentUpdateExpressionWAT(lhs Expression, updateWAT Code, evaluateToOldValue bool) (Code, error) { switch lhs.Type { case Expression_VariableReference: ref := lhs.Value.(VariableReferenceExpression) @@ -330,25 +349,27 @@ func (c *Compiler) compileAssignmentUpdateExpressionWAT(lhs Expression, updateWA exprWAT, err := c.compileExpressionWAT(lhs) if err != nil { - return "", err + return emptyCode(), err } var tmpLocal Local - wat := exprWAT + wat := exprWAT.clone() if evaluateToOldValue { tmpLocal = Local{Name: "", Type: *lhs.ValueType, IsParameter: false, Index: len(c.CurrentFunction.Locals)} c.CurrentFunction.Locals = append(c.CurrentFunction.Locals, tmpLocal) - wat += "local.tee $" + strconv.Itoa(tmpLocal.Index) + "\n" + wat.add("local.tee $" + strconv.Itoa(tmpLocal.Index)) } - wat += updateWAT + wat.addAll(updateWAT) if evaluateToOldValue { - wat += "local.set $" + local + "\n" - wat += "local.get $" + strconv.Itoa(tmpLocal.Index) + "\n" + wat.add( + "local.set $"+local, + "local.get $"+strconv.Itoa(tmpLocal.Index), + ) } else { - wat += "local.tee $" + local + "\n" + wat.add("local.tee $" + local) } return wat, nil @@ -369,29 +390,33 @@ func (c *Compiler) compileAssignmentUpdateExpressionWAT(lhs Expression, updateWA addrWAT, err := c.compileExpressionWAT(raw.Address) if err != nil { - return "", err + return emptyCode(), err } wat := addrWAT // dup address - wat += "local.tee $" + strconv.Itoa(localAddress.Index) + "\n" - wat += "local.get $" + strconv.Itoa(localAddress.Index) + "\n" + wat.add( + "local.tee $"+strconv.Itoa(localAddress.Index), + "local.get $"+strconv.Itoa(localAddress.Index), + ) - wat += c.getWATType(raw.Type) + ".load\n" + wat.add(c.getWATType(raw.Type) + ".load") if evaluateToOldValue { - wat += "local.tee $" + strconv.Itoa(localValue.Index) + "\n" + wat.add("local.tee $" + strconv.Itoa(localValue.Index)) } - wat += updateWAT + wat.addAll(updateWAT) if !evaluateToOldValue { - wat += "local.tee $" + strconv.Itoa(localValue.Index) + "\n" + wat.add("local.tee $" + strconv.Itoa(localValue.Index)) } - wat += c.getWATType(raw.Type) + ".store\n" - wat += "local.get $" + strconv.Itoa(localValue.Index) + "\n" + wat.add( + c.getWATType(raw.Type)+".store", + "local.get $"+strconv.Itoa(localValue.Index), + ) return wat, nil } @@ -399,7 +424,7 @@ func (c *Compiler) compileAssignmentUpdateExpressionWAT(lhs Expression, updateWA panic("assignment expr not implemented") } -func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) { +func (c *Compiler) compileExpressionWAT(expr Expression) (Code, error) { var err error switch expr.Type { @@ -412,12 +437,13 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) { watRight, err := c.compileExpressionWAT(ass.Value) if err != nil { - return "", err + return emptyCode(), err } updateOp := c.compileOperationWAT(ass.Operation, ass.Lhs.ValueType.Value.(PrimitiveType)) + watRight.addAll(updateOp) - return c.compileAssignmentUpdateExpressionWAT(ass.Lhs, watRight+updateOp, false) + return c.compileAssignmentUpdateExpressionWAT(ass.Lhs, watRight, false) case Expression_Literal: lit := expr.Value.(LiteralExpression) switch lit.Literal.Type { @@ -425,9 +451,9 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) { return pushConstantNumberWAT(lit.Literal.Primitive, lit.Literal.Value), nil case Literal_Boolean: if lit.Literal.Value.(bool) { - return "i32.const 1\n", nil + return ofInstruction("i32.const 1"), nil } else { - return "i32.const 0\n", nil + return ofInstruction("i32.const 0"), nil } case Literal_String: panic("not implemented") @@ -435,13 +461,17 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) { case Expression_VariableReference: ref := expr.Value.(VariableReferenceExpression) - cast := "" + cast := emptyCode() if expr.ValueType.Type == Type_Primitive { // TODO: technically only needed for function parameters because functions can be called from outside WASM so they might not be fully type checked cast = getTypeCast(expr.ValueType.Value.(PrimitiveType)) } - return "local.get $" + strconv.Itoa(getLocal(c.CurrentBlock, ref.Variable).Index) + "\n" + cast, nil + code := emptyCode() + code.add("local.get $" + strconv.Itoa(getLocal(c.CurrentBlock, ref.Variable).Index)) + code.addAll(cast) + + return code, nil case Expression_Binary: binary := expr.Value.(BinaryExpression) @@ -451,39 +481,39 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) { watLeft, err := c.compileExpressionWAT(binary.Left) if err != nil { - return "", err + return emptyCode(), err } watRight, err := c.compileExpressionWAT(binary.Right) if err != nil { - return "", err + return emptyCode(), err } op := c.compileOperationWAT(binary.Operation, operandType) - return watLeft + watRight + op + getTypeCast(exprType), nil + return concat(watLeft, watRight, op, getTypeCast(exprType)), nil case Expression_Tuple: tuple := expr.Value.(TupleExpression) - wat := "" + wat := emptyCode() for _, member := range tuple.Members { memberWAT, err := c.compileExpressionWAT(member) if err != nil { - return "", err + return emptyCode(), err } - wat += memberWAT + wat.addAll(memberWAT) } return wat, nil case Expression_FunctionCall: fc := expr.Value.(FunctionCallExpression) - wat := "" + wat := emptyCode() if fc.Parameters != nil { wat, err = c.compileExpressionWAT(*fc.Parameters) if err != nil { - return "", err + return emptyCode(), err } } @@ -492,39 +522,40 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) { if !c.Wasm64 { cast, err := castPrimitiveWAT(Primitive_U64, Primitive_U32) if err != nil { - return "", err + return emptyCode(), err } - wat += cast + wat.addAll(cast) } - wat += "memory.grow\n" + wat.add("memory.grow") if !c.Wasm64 { cast, err := castPrimitiveWAT(Primitive_I32, Primitive_I64) if err != nil { - return "", err + return emptyCode(), err } - wat += cast + wat.addAll(cast) } return wat, nil case BUILTIN_MEMORY_SIZE: - wat += "memory.size\n" + wat.add("memory.size") if !c.Wasm64 { cast, err := castPrimitiveWAT(Primitive_U32, Primitive_U64) if err != nil { - return "", err + return emptyCode(), err } - wat += cast + wat.addAll(cast) } return wat, nil default: - return wat + "call $" + fc.Function + "\n", nil + wat.add("call $" + fc.Function) + return wat, nil } case Expression_Unary: unary := expr.Value.(UnaryExpression) @@ -532,35 +563,46 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) { wat, err := c.compileExpressionWAT(unary.Value) if err != nil { - return "", err + return emptyCode(), err } watType := getPrimitiveWATType(exprType) switch unary.Operation { case UnaryOperation_Negate: if isFloatingPoint(exprType) { - return wat + watType + ".neg\n", nil + wat.add(watType + ".neg") + return wat, nil } else { - return watType + ".const 0\n" + wat + watType + ".sub\n" + getTypeCast(exprType), nil + code := emptyCode() + code.add(watType + ".const 0") + code.addAll(wat) + code.add(watType + ".sub") + code.addAll(getTypeCast(exprType)) + return code, nil } case UnaryOperation_Nop: return wat, nil case UnaryOperation_BitwiseNot: if getBits(exprType) == 64 { - return wat + watType + ".const 0xFFFFFFFFFFFFFFFF\n" + watType + ".xor\n" + getTypeCast(exprType), nil + wat.add(watType + ".const 0xFFFFFFFFFFFFFFFF") } else { - return wat + watType + ".const 0xFFFFFFFF\n" + watType + ".xor\n" + getTypeCast(exprType), nil + wat.add(watType + ".const 0xFFFFFFFF") } + + wat.add(watType + ".xor") + wat.addAll(getTypeCast(exprType)) + return wat, nil case UnaryOperation_LogicalNot: - return wat + "i32.eqz\n", nil + wat.add("i32.eqz") + return wat, nil case UnaryOperation_PreIncrement, UnaryOperation_PreDecrement, UnaryOperation_PostIncrement, UnaryOperation_PostDecrement: valueType := c.getWATType(*unary.Value.ValueType) - updateWAT := valueType + ".const 1\n" + updateWAT := ofInstruction(valueType + ".const 1") if unary.Operation == UnaryOperation_PreIncrement || unary.Operation == UnaryOperation_PostIncrement { - updateWAT += valueType + ".add\n" + updateWAT.add(valueType + ".add") } else { - updateWAT += valueType + ".sub\n" + updateWAT.add(valueType + ".sub") } return c.compileAssignmentUpdateExpressionWAT(unary.Value, updateWAT, unary.Operation == UnaryOperation_PostIncrement || unary.Operation == UnaryOperation_PostDecrement) @@ -570,7 +612,7 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) { wat, err := c.compileExpressionWAT(cast.Value) if err != nil { - return "", err + return emptyCode(), err } // TODO: fine, as it is currently only allowed for primitive types @@ -578,20 +620,20 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) { toType := cast.Type.Value.(PrimitiveType) castWAT, err := castPrimitiveWAT(fromType, toType) if err != nil { - return "", err + return emptyCode(), err } - return wat + castWAT, nil + return concat(wat, castWAT), nil case Expression_RawMemoryReference: raw := expr.Value.(RawMemoryReferenceExpression) wat, err := c.compileExpressionWAT(raw.Address) if err != nil { - return "", err + return emptyCode(), err } if raw.Type.Type == Type_Primitive { - wat += c.getWATType(raw.Type) + ".load\n" // TODO: use load8/load16(_s/u) for smaller types + wat.add(c.getWATType(raw.Type) + ".load") // TODO: use load8/load16(_s/u) for smaller types } return wat, nil @@ -606,60 +648,66 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) { arrayWAT, err := c.compileExpressionWAT(array.Array) if err != nil { - return "", err + return emptyCode(), err } - arrayWAT += "local.set $" + strconv.Itoa(localArray.Index) + "\n" + arrayWAT.add("local.set $" + strconv.Itoa(localArray.Index)) indexWAT, err := c.compileExpressionWAT(array.Index) if err != nil { - return "", err + return emptyCode(), err } if !c.Wasm64 { cast, err := castPrimitiveWAT(Primitive_I64, Primitive_I32) if err != nil { - return "", err + return emptyCode(), err } - indexWAT += cast + indexWAT.addAll(cast) } - indexWAT += "local.set $" + strconv.Itoa(localIndex.Index) + "\n" + indexWAT.add("local.set $" + strconv.Itoa(localIndex.Index)) - wat := arrayWAT + indexWAT + wat := concat(arrayWAT, indexWAT) if _, ok := c.CompileOptions[COMPILE_OPTION_NO_BOUNDS_CHECK]; !ok { // Error if index < 0 - wat += "block\n" - wat += "local.get $" + strconv.Itoa(localIndex.Index) + "\n" - wat += c.getAddressWATType() + ".const 0\n" - wat += c.getAddressWATType() + ".ge_s\n" - wat += "br_if 0\n" - wat += "call $__builtin_panic\n" - wat += "end\n" + wat.add( + "block", + "local.get $"+strconv.Itoa(localIndex.Index), + c.getAddressWATType()+".const 0", + c.getAddressWATType()+".ge_s", + "br_if 0", + "call $__builtin_panic", + "end", + ) // Error if index >= array length - wat += "block\n" - wat += "local.get $" + strconv.Itoa(localIndex.Index) + "\n" - wat += "local.get $" + strconv.Itoa(localArray.Index) + "\n" - wat += "i32.load\n" // Load array length - wat += c.getAddressWATType() + ".lt_s\n" - wat += "br_if 0\n" - wat += "call $__builtin_panic\n" - wat += "end\n" + wat.add( + "block", + "local.get $"+strconv.Itoa(localIndex.Index), + "local.get $"+strconv.Itoa(localArray.Index), + "i32.load", // Load array length + c.getAddressWATType()+".lt_s", + "br_if 0", + "call $__builtin_panic", + "end", + ) } elementType := array.Array.ValueType.Value.(ArrayType).ElementType - wat += "local.get $" + strconv.Itoa(localIndex.Index) + "\n" - wat += c.getAddressWATType() + ".const " + strconv.Itoa(c.getTypeSizeBytes(elementType)) + "\n" - wat += c.getAddressWATType() + ".mul\n" - wat += "local.get $" + strconv.Itoa(localArray.Index) + "\n" - wat += c.getAddressWATType() + ".add\n" - wat += c.getAddressWATType() + ".const 4\n" // first 4 bytes = length - wat += c.getAddressWATType() + ".add\n" + wat.add( + "local.get $"+strconv.Itoa(localIndex.Index), + c.getAddressWATType()+".const "+strconv.Itoa(c.getTypeSizeBytes(elementType)), + c.getAddressWATType()+".mul", + "local.get $"+strconv.Itoa(localArray.Index), + c.getAddressWATType()+".add", + c.getAddressWATType()+".const 4", // first 4 bytes = length + c.getAddressWATType()+".add", + ) - wat += c.getWATType(elementType) + ".load\n" // TODO: use load8/load16(_s/u) for smaller types + wat.add(c.getWATType(elementType) + ".load") // TODO: use load8/load16(_s/u) for smaller types return wat, nil } @@ -667,10 +715,10 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) { panic("expr not implemented") } -func (c *Compiler) compileOperationWAT(operation Operation, operandType PrimitiveType) string { - op := "" +func (c *Compiler) compileOperationWAT(operation Operation, operandType PrimitiveType) Code { + var op string - suffix := "" + var suffix string if isUnsignedInt(operandType) { suffix = "u" } else { @@ -679,41 +727,41 @@ func (c *Compiler) compileOperationWAT(operation Operation, operandType Primitiv switch operation { case Operation_Add: - op = getPrimitiveWATType(operandType) + ".add\n" + op = getPrimitiveWATType(operandType) + ".add" case Operation_Sub: - op = getPrimitiveWATType(operandType) + ".sub\n" + op = getPrimitiveWATType(operandType) + ".sub" case Operation_Mul: - op = getPrimitiveWATType(operandType) + ".mul\n" + op = getPrimitiveWATType(operandType) + ".mul" case Operation_Div: - op = getPrimitiveWATType(operandType) + ".div_" + suffix + "\n" + op = getPrimitiveWATType(operandType) + ".div_" + suffix case Operation_Mod: - op = getPrimitiveWATType(operandType) + ".rem_" + suffix + "\n" + op = getPrimitiveWATType(operandType) + ".rem_" + suffix case Operation_Greater: - op = getPrimitiveWATType(operandType) + ".gt_" + suffix + "\n" + op = getPrimitiveWATType(operandType) + ".gt_" + suffix case Operation_Less: - op = getPrimitiveWATType(operandType) + ".lt_" + suffix + "\n" + op = getPrimitiveWATType(operandType) + ".lt_" + suffix case Operation_GreaterEquals: - op = getPrimitiveWATType(operandType) + ".ge_" + suffix + "\n" + op = getPrimitiveWATType(operandType) + ".ge_" + suffix case Operation_LessEquals: - op = getPrimitiveWATType(operandType) + ".le_" + suffix + "\n" + op = getPrimitiveWATType(operandType) + ".le_" + suffix case Operation_NotEquals: - op = getPrimitiveWATType(operandType) + ".ne\n" + op = getPrimitiveWATType(operandType) + ".ne" case Operation_Equals: - op = getPrimitiveWATType(operandType) + ".eq\n" + op = getPrimitiveWATType(operandType) + ".eq" default: panic("operation not implemented") } - return op + return ofInstruction(op) } -func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, error) { +func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (Code, error) { switch stmt.Type { case Statement_Expression: expr := stmt.Value.(ExpressionStatement) wat, err := c.compileExpressionWAT(expr.Expression) if err != nil { - return "", err + return emptyCode(), err } numItems := 0 @@ -725,85 +773,91 @@ func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, er } } - return wat + strings.Repeat("drop\n", numItems), nil + for _ = range numItems { + wat.add("drop") + } + + return wat, nil case Statement_Block: block := stmt.Value.(BlockStatement) wat, err := c.compileBlockWAT(block.Block) if err != nil { - return "", err + return emptyCode(), err } return wat, nil case Statement_Return: ret := stmt.Value.(ReturnStatement) - wat := "" + wat := emptyCode() if ret.Value != nil { valueWAT, err := c.compileExpressionWAT(*ret.Value) if err != nil { - return "", err + return emptyCode(), err } - wat += valueWAT + wat.addAll(valueWAT) } - return wat + "return\n", nil + wat.add("return") + return wat, nil case Statement_DeclareLocalVariable: dlv := stmt.Value.(DeclareLocalVariableStatement) if dlv.Initializer == nil { - return "", nil + return emptyCode(), nil } wat, err := c.compileExpressionWAT(*dlv.Initializer) if err != nil { - return "", err + return emptyCode(), err } - return wat + "local.set $" + strconv.Itoa(block.Locals[dlv.Variable].Index) + "\n", nil + wat.add("local.set $" + strconv.Itoa(block.Locals[dlv.Variable].Index)) + return wat, nil case Statement_If: ifS := stmt.Value.(IfStatement) conditionWAT, err := c.compileExpressionWAT(ifS.Condition) if err != nil { - return "", err + return emptyCode(), err } condBlockWAT, err := c.compileBlockWAT(ifS.ConditionalBlock) if err != nil { - return "", err + return emptyCode(), err } - wat := "" + wat := emptyCode() if ifS.ElseBlock != nil { - wat += "block\n" + wat.add("block") } // condition - wat += "block\n" + wat.add("block") - wat += conditionWAT - wat += "i32.eqz\n" // logical not - wat += "br_if 0\n" + wat.addAll(conditionWAT) + wat.add("i32.eqz") // logical not + wat.add("br_if 0") // condition is true - wat += condBlockWAT + wat.addAll(condBlockWAT) if ifS.ElseBlock != nil { - wat += "br 1\n" // jump over else block + wat.add("br 1") // jump over else block } - wat += "end\n" + wat.add("end") if ifS.ElseBlock != nil { // condition is false elseWAT, err := c.compileBlockWAT(ifS.ElseBlock) if err != nil { - return "", err + return emptyCode(), err } - wat += elseWAT - wat += "end\n" + wat.addAll(elseWAT) + wat.add("end") } return wat, nil @@ -812,26 +866,27 @@ func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, er conditionWAT, err := c.compileExpressionWAT(while.Condition) if err != nil { - return "", err + return emptyCode(), err } bodyWAT, err := c.compileBlockWAT(while.Body) if err != nil { - return "", err + return emptyCode(), err } - wat := "block\n" - wat += "loop\n" + wat := emptyCode() + wat.add("block") + wat.add("loop") - wat += conditionWAT - wat += "i32.eqz\n" - wat += "br_if 1\n" + wat.addAll(conditionWAT) + wat.add("i32.eqz") + wat.add("br_if 1") - wat += bodyWAT - wat += "br 0\n" + wat.addAll(bodyWAT) + wat.add("br 0") - wat += "end\n" - wat += "end\n" + wat.add("end") + wat.add("end") return wat, nil } @@ -839,17 +894,17 @@ func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, er panic("stmt not implemented") } -func (c *Compiler) compileBlockWAT(block *Block) (string, error) { - blockWAT := "" +func (c *Compiler) compileBlockWAT(block *Block) (Code, error) { + blockWAT := emptyCode() for _, stmt := range block.Statements { c.CurrentBlock = block wat, err := c.compileStatementWAT(stmt, block) if err != nil { - return "", err + return emptyCode(), err } - blockWAT += wat + blockWAT.addAll(wat) } return blockWAT, nil @@ -890,14 +945,15 @@ func (c *Compiler) compileFunctionWAT(function *ParsedFunction) (string, error) funcWAT += "\t(local $" + strconv.Itoa(local.Index) + " " + c.getWATType(local.Type) + ")\n" } - return funcWAT + blockWat + ")\n", nil + return funcWAT + blockWat.toString() + ")\n", nil } func (c *Compiler) compile() (string, error) { - module := "(module (memory $memory 0) (export \"memory\" (memory $memory))\n" + module := "(module\n" + module += "(memory $memory 0) (export \"memory\" (memory $memory))\n" module += "(func $__builtin_panic\n" - module += "unreachable\n" + module += "\tunreachable\n" module += ")\n" for _, file := range c.Files { diff --git a/compiler.go b/compiler.go new file mode 100644 index 0000000..5237569 --- /dev/null +++ b/compiler.go @@ -0,0 +1,66 @@ +package main + +import "strings" + +type Code struct { + Instructions []string +} + +func (code *Code) add(instructions ...string) { + code.Instructions = append(code.Instructions, instructions...) +} + +func (code *Code) addAll(other Code) { + code.Instructions = append(code.Instructions, other.Instructions...) +} + +func ofInstruction(instructions ...string) Code { + return Code{Instructions: instructions} +} + +func emptyCode() Code { + return Code{Instructions: []string{}} +} + +func clone(code Code) Code { + newInstrs := make([]string, len(code.Instructions)) + copy(newInstrs, code.Instructions) + return Code{Instructions: newInstrs} +} + +func concat(code Code, other ...Code) Code { + newCode := clone(code) + + for _, o := range other { + newCode.addAll(o) + } + + return newCode +} + +func (code Code) clone() Code { + return clone(code) +} + +func (code *Code) isEmpty() bool { + return len(code.Instructions) == 0 +} + +func (code *Code) toString() string { + text := "" + + indent := 1 + for _, instr := range code.Instructions { + if instr == "end" { + indent-- + } + + text += strings.Repeat("\t", indent) + instr + "\n" + + if instr == "block" || instr == "loop" { + indent++ + } + } + + return text +} diff --git a/go.mod b/go.mod index da57197..3fda6b7 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module git.cringe-studios.com/mr/elysium -go 1.21.7 +go 1.23.2