Improve code handling

This commit is contained in:
MrLetsplay 2024-11-16 22:41:26 +01:00
parent 61d8694a02
commit b5c42a6690
Signed by: mr
SSH Key Fingerprint: SHA256:92jBH80vpXyaZHjaIl47pjRq+Yt7XGTArqQg1V7hSqg
3 changed files with 336 additions and 214 deletions

View File

@ -4,7 +4,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
"strings"
"unicode" "unicode"
) )
@ -50,37 +49,37 @@ func safeASCIIIdentifier(identifier string) string {
return ascii return ascii
} }
func getTypeCast(primitive PrimitiveType) string { func getTypeCast(primitive PrimitiveType) Code {
switch primitive { switch primitive {
case Primitive_I8: case Primitive_I8:
return "i32.extend8_s\n" return ofInstruction("i32.extend8_s")
case Primitive_U8: case Primitive_U8:
return "i32.const 255\ni32.and\n" return ofInstruction("i32.const 255", "i32.and")
case Primitive_I16: case Primitive_I16:
return "i32.extend16_s\n" return ofInstruction("i32.extend16_s")
case Primitive_U16: case Primitive_U16:
return "i32.const 65535\ni32.and\n" return ofInstruction("i32.const 65535", "i32.and")
case Primitive_Bool: case Primitive_Bool:
return "i32.const 0\ni32.ne\n" return ofInstruction("i32.const 0", "i32.ne")
} }
return "" return emptyCode()
} }
func pushConstantNumberWAT(primitive PrimitiveType, value any) string { func pushConstantNumberWAT(primitive PrimitiveType, value any) Code {
switch primitive { switch primitive {
case Primitive_I8, Primitive_I16, Primitive_I32: case Primitive_I8, Primitive_I16, Primitive_I32:
return "i32.const " + strconv.FormatInt(value.(int64), 10) + "\n" return ofInstruction("i32.const " + strconv.FormatInt(value.(int64), 10))
case Primitive_U8, Primitive_U16, Primitive_U32: case Primitive_U8, Primitive_U16, Primitive_U32:
return "i32.const " + strconv.FormatUint(value.(uint64), 10) + "\n" return ofInstruction("i32.const " + strconv.FormatUint(value.(uint64), 10))
case Primitive_I64: case Primitive_I64:
return "i64.const " + strconv.FormatInt(value.(int64), 10) + "\n" return ofInstruction("i64.const " + strconv.FormatInt(value.(int64), 10))
case Primitive_U64: case Primitive_U64:
return "i64.const " + strconv.FormatUint(value.(uint64), 10) + "\n" return ofInstruction("i64.const " + strconv.FormatUint(value.(uint64), 10))
case Primitive_F32: case Primitive_F32:
return "f32.const " + strconv.FormatFloat(value.(float64), 'f', -1, 32) + "\n" return ofInstruction("f32.const " + strconv.FormatFloat(value.(float64), 'f', -1, 32))
case Primitive_F64: case Primitive_F64:
return "f64.const " + strconv.FormatFloat(value.(float64), 'f', -1, 64) + "\n" return ofInstruction("f64.const " + strconv.FormatFloat(value.(float64), 'f', -1, 64))
} }
panic(fmt.Sprintf("invalid type passed to pushConstantNumberWAT(): %s", primitive)) panic(fmt.Sprintf("invalid type passed to pushConstantNumberWAT(): %s", primitive))
@ -143,38 +142,39 @@ func (c *Compiler) getTypeSizeBytes(t Type) int {
panic(fmt.Sprintf("unhandled type in getTypeSizeBytes(): %s", t)) panic(fmt.Sprintf("unhandled type in getTypeSizeBytes(): %s", t))
} }
func castPrimitiveWAT(from PrimitiveType, to PrimitiveType) (string, error) { func castPrimitiveWAT(from PrimitiveType, to PrimitiveType) (Code, error) {
if from == to { if from == to {
return "", nil return emptyCode(), nil
} }
if from == Primitive_Bool || to == Primitive_Bool { if from == Primitive_Bool || to == Primitive_Bool {
return "", errors.New("cannot upcast from or to bool") return emptyCode(), errors.New("cannot upcast from or to bool")
} }
fromFloat := isFloatingPoint(from) fromFloat := isFloatingPoint(from)
toFloat := isFloatingPoint(to) toFloat := isFloatingPoint(to)
if fromFloat && toFloat { if fromFloat && toFloat {
if to == Primitive_F32 { if to == Primitive_F32 {
return "f32.demote_f64\n", nil return ofInstruction("f32.demote_f64"), nil
} else { } else {
return "f64.promote_f32\n", nil return ofInstruction("f64.promote_f32"), nil
} }
} }
if toFloat { if toFloat {
suffix := "" var suffix string
if isUnsignedInt(to) { if isUnsignedInt(to) {
suffix = "u" suffix = "u"
} else { } else {
suffix = "s" suffix = "s"
} }
return getPrimitiveWATType(to) + ".convert_" + getPrimitiveWATType(from) + "_" + suffix + "\n", nil return ofInstruction(getPrimitiveWATType(to) + ".convert_" + getPrimitiveWATType(from) + "_" + suffix), nil
} }
if fromFloat { if fromFloat {
suffix := "" var suffix string
if isUnsignedInt(to) { if isUnsignedInt(to) {
suffix = "u" suffix = "u"
@ -182,11 +182,11 @@ func castPrimitiveWAT(from PrimitiveType, to PrimitiveType) (string, error) {
suffix = "s" suffix = "s"
} }
return getPrimitiveWATType(to) + ".trunc_" + getPrimitiveWATType(from) + "_" + suffix + "\n", nil return ofInstruction(getPrimitiveWATType(to) + ".trunc_" + getPrimitiveWATType(from) + "_" + suffix), nil
} }
if getBits(from) == getBits(to) { if getBits(from) == getBits(to) {
return "", nil return emptyCode(), nil
} }
if getPrimitiveWATType(from) == getPrimitiveWATType(to) { if getPrimitiveWATType(from) == getPrimitiveWATType(to) {
@ -194,36 +194,41 @@ func castPrimitiveWAT(from PrimitiveType, to PrimitiveType) (string, error) {
return getTypeCast(to), nil return getTypeCast(to), nil
} }
return "", nil return emptyCode(), nil
} }
if getBits(from) < 64 && getBits(to) == 64 { if getBits(from) < 64 && getBits(to) == 64 {
suffix := "" var suffix string
if isUnsignedInt(from) { if isUnsignedInt(from) {
suffix = "u" suffix = "u"
} else { } else {
suffix = "s" suffix = "s"
} }
return "i64.extend_i32_" + suffix + "\n", nil return ofInstruction("i64.extend_i32_" + suffix), nil
} }
return "i32.wrap_i64\n" + getTypeCast(to), nil code := ofInstruction("i32.wrap_i64")
code.addAll(getTypeCast(to))
return code, nil
} }
func (c *Compiler) compileAssignmentExpressionWAT(assignment AssignmentExpression) (string, error) { func (c *Compiler) compileAssignmentExpressionWAT(assignment AssignmentExpression) (Code, error) {
lhs := assignment.Lhs lhs := assignment.Lhs
exprWAT, err := c.compileExpressionWAT(assignment.Value) exprWAT, err := c.compileExpressionWAT(assignment.Value)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
switch lhs.Type { switch lhs.Type {
case Expression_VariableReference: case Expression_VariableReference:
ref := lhs.Value.(VariableReferenceExpression) ref := lhs.Value.(VariableReferenceExpression)
local := strconv.Itoa(getLocal(c.CurrentBlock, ref.Variable).Index) local := strconv.Itoa(getLocal(c.CurrentBlock, ref.Variable).Index)
return exprWAT + "local.tee $" + local + "\n", nil
exprWAT.add("local.tee $" + local)
return exprWAT, nil
case Expression_ArrayAccess: case Expression_ArrayAccess:
array := lhs.Value.(ArrayAccessExpression) array := lhs.Value.(ArrayAccessExpression)
@ -238,63 +243,72 @@ func (c *Compiler) compileAssignmentExpressionWAT(assignment AssignmentExpressio
arrayWAT, err := c.compileExpressionWAT(array.Array) arrayWAT, err := c.compileExpressionWAT(array.Array)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
arrayWAT += "local.set $" + strconv.Itoa(localArray.Index) + "\n" arrayWAT.add("local.set $" + strconv.Itoa(localArray.Index))
indexWAT, err := c.compileExpressionWAT(array.Index) indexWAT, err := c.compileExpressionWAT(array.Index)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
if !c.Wasm64 { if !c.Wasm64 {
cast, err := castPrimitiveWAT(Primitive_I64, Primitive_I32) cast, err := castPrimitiveWAT(Primitive_I64, Primitive_I32)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
indexWAT += cast indexWAT.addAll(cast)
} }
indexWAT += "local.set $" + strconv.Itoa(localIndex.Index) + "\n" indexWAT.add("local.set $" + strconv.Itoa(localIndex.Index))
wat := arrayWAT + indexWAT wat := concat(arrayWAT, indexWAT)
if _, ok := c.CompileOptions[COMPILE_OPTION_NO_BOUNDS_CHECK]; !ok { if _, ok := c.CompileOptions[COMPILE_OPTION_NO_BOUNDS_CHECK]; !ok {
// Error if index < 0 // Error if index < 0
wat += "block\n" wat.add(
wat += "local.get $" + strconv.Itoa(localIndex.Index) + "\n" "block",
wat += c.getAddressWATType() + ".const 0\n" "local.get $"+strconv.Itoa(localIndex.Index),
wat += c.getAddressWATType() + ".ge_s\n" c.getAddressWATType()+".const 0",
wat += "br_if 0\n" c.getAddressWATType()+".ge_s",
wat += "call $__builtin_panic\n" "br_if 0",
wat += "end\n" "call $__builtin_panic",
"end",
)
// Error if index >= array length // Error if index >= array length
wat += "block\n" wat.add(
wat += "local.get $" + strconv.Itoa(localIndex.Index) + "\n" "block",
wat += "local.get $" + strconv.Itoa(localArray.Index) + "\n" "local.get $"+strconv.Itoa(localIndex.Index),
wat += "i32.load\n" // Load array length "local.get $"+strconv.Itoa(localArray.Index),
wat += c.getAddressWATType() + ".lt_s\n" "i32.load", // Load array length
wat += "br_if 0\n" c.getAddressWATType()+".lt_s",
wat += "call $__builtin_panic\n" "br_if 0",
wat += "end\n" "call $__builtin_panic",
"end",
)
} }
elementType := array.Array.ValueType.Value.(ArrayType).ElementType elementType := array.Array.ValueType.Value.(ArrayType).ElementType
wat += "local.get $" + strconv.Itoa(localIndex.Index) + "\n" wat.add(
wat += c.getAddressWATType() + ".const " + strconv.Itoa(c.getTypeSizeBytes(elementType)) + "\n" "local.get $"+strconv.Itoa(localIndex.Index),
wat += c.getAddressWATType() + ".mul\n" c.getAddressWATType()+".const "+strconv.Itoa(c.getTypeSizeBytes(elementType)),
wat += "local.get $" + strconv.Itoa(localArray.Index) + "\n" c.getAddressWATType()+".mul",
wat += c.getAddressWATType() + ".add\n" "local.get $"+strconv.Itoa(localArray.Index),
wat += c.getAddressWATType() + ".const 4\n" // first 4 bytes = length c.getAddressWATType()+".add",
wat += c.getAddressWATType() + ".add\n" c.getAddressWATType()+".const 4", // first 4 bytes = length
c.getAddressWATType()+".add",
)
wat += exprWAT wat.addAll(exprWAT)
wat += "local.tee $" + strconv.Itoa(localElement.Index) + "\n"
wat += c.getWATType(elementType) + ".store\n" // TODO: use load8/load16(_s/u) for smaller types wat.add(
wat += "local.get $" + strconv.Itoa(localElement.Index) + "\n" "local.tee $"+strconv.Itoa(localElement.Index),
c.getWATType(elementType)+".store", // TODO: use load8/load16(_s/u) for smaller types
"local.get $"+strconv.Itoa(localElement.Index),
)
return wat, nil return wat, nil
case Expression_RawMemoryReference: case Expression_RawMemoryReference:
@ -309,20 +323,25 @@ func (c *Compiler) compileAssignmentExpressionWAT(assignment AssignmentExpressio
addrWAT, err := c.compileExpressionWAT(raw.Address) addrWAT, err := c.compileExpressionWAT(raw.Address)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
// TODO: should leave a copy of the stored value on the stack // TODO: should leave a copy of the stored value on the stack
return addrWAT + exprWAT + code := addrWAT.clone()
"local.tee $" + strconv.Itoa(local.Index) + "\n" + code.addAll(exprWAT)
c.getWATType(raw.Type) + ".store\n" + code.add(
"local.get $" + strconv.Itoa(local.Index) + "\n", nil "local.tee $"+strconv.Itoa(local.Index),
c.getWATType(raw.Type)+".store",
"local.get $"+strconv.Itoa(local.Index),
)
return code, nil
} }
panic("assignment expr not implemented") panic("assignment expr not implemented")
} }
func (c *Compiler) compileAssignmentUpdateExpressionWAT(lhs Expression, updateWAT string, evaluateToOldValue bool) (string, error) { func (c *Compiler) compileAssignmentUpdateExpressionWAT(lhs Expression, updateWAT Code, evaluateToOldValue bool) (Code, error) {
switch lhs.Type { switch lhs.Type {
case Expression_VariableReference: case Expression_VariableReference:
ref := lhs.Value.(VariableReferenceExpression) ref := lhs.Value.(VariableReferenceExpression)
@ -330,25 +349,27 @@ func (c *Compiler) compileAssignmentUpdateExpressionWAT(lhs Expression, updateWA
exprWAT, err := c.compileExpressionWAT(lhs) exprWAT, err := c.compileExpressionWAT(lhs)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
var tmpLocal Local var tmpLocal Local
wat := exprWAT wat := exprWAT.clone()
if evaluateToOldValue { if evaluateToOldValue {
tmpLocal = Local{Name: "", Type: *lhs.ValueType, IsParameter: false, Index: len(c.CurrentFunction.Locals)} tmpLocal = Local{Name: "", Type: *lhs.ValueType, IsParameter: false, Index: len(c.CurrentFunction.Locals)}
c.CurrentFunction.Locals = append(c.CurrentFunction.Locals, tmpLocal) c.CurrentFunction.Locals = append(c.CurrentFunction.Locals, tmpLocal)
wat += "local.tee $" + strconv.Itoa(tmpLocal.Index) + "\n" wat.add("local.tee $" + strconv.Itoa(tmpLocal.Index))
} }
wat += updateWAT wat.addAll(updateWAT)
if evaluateToOldValue { if evaluateToOldValue {
wat += "local.set $" + local + "\n" wat.add(
wat += "local.get $" + strconv.Itoa(tmpLocal.Index) + "\n" "local.set $"+local,
"local.get $"+strconv.Itoa(tmpLocal.Index),
)
} else { } else {
wat += "local.tee $" + local + "\n" wat.add("local.tee $" + local)
} }
return wat, nil return wat, nil
@ -369,29 +390,33 @@ func (c *Compiler) compileAssignmentUpdateExpressionWAT(lhs Expression, updateWA
addrWAT, err := c.compileExpressionWAT(raw.Address) addrWAT, err := c.compileExpressionWAT(raw.Address)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
wat := addrWAT wat := addrWAT
// dup address // dup address
wat += "local.tee $" + strconv.Itoa(localAddress.Index) + "\n" wat.add(
wat += "local.get $" + strconv.Itoa(localAddress.Index) + "\n" "local.tee $"+strconv.Itoa(localAddress.Index),
"local.get $"+strconv.Itoa(localAddress.Index),
)
wat += c.getWATType(raw.Type) + ".load\n" wat.add(c.getWATType(raw.Type) + ".load")
if evaluateToOldValue { if evaluateToOldValue {
wat += "local.tee $" + strconv.Itoa(localValue.Index) + "\n" wat.add("local.tee $" + strconv.Itoa(localValue.Index))
} }
wat += updateWAT wat.addAll(updateWAT)
if !evaluateToOldValue { if !evaluateToOldValue {
wat += "local.tee $" + strconv.Itoa(localValue.Index) + "\n" wat.add("local.tee $" + strconv.Itoa(localValue.Index))
} }
wat += c.getWATType(raw.Type) + ".store\n" wat.add(
wat += "local.get $" + strconv.Itoa(localValue.Index) + "\n" c.getWATType(raw.Type)+".store",
"local.get $"+strconv.Itoa(localValue.Index),
)
return wat, nil return wat, nil
} }
@ -399,7 +424,7 @@ func (c *Compiler) compileAssignmentUpdateExpressionWAT(lhs Expression, updateWA
panic("assignment expr not implemented") panic("assignment expr not implemented")
} }
func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) { func (c *Compiler) compileExpressionWAT(expr Expression) (Code, error) {
var err error var err error
switch expr.Type { switch expr.Type {
@ -412,12 +437,13 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) {
watRight, err := c.compileExpressionWAT(ass.Value) watRight, err := c.compileExpressionWAT(ass.Value)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
updateOp := c.compileOperationWAT(ass.Operation, ass.Lhs.ValueType.Value.(PrimitiveType)) updateOp := c.compileOperationWAT(ass.Operation, ass.Lhs.ValueType.Value.(PrimitiveType))
watRight.addAll(updateOp)
return c.compileAssignmentUpdateExpressionWAT(ass.Lhs, watRight+updateOp, false) return c.compileAssignmentUpdateExpressionWAT(ass.Lhs, watRight, false)
case Expression_Literal: case Expression_Literal:
lit := expr.Value.(LiteralExpression) lit := expr.Value.(LiteralExpression)
switch lit.Literal.Type { switch lit.Literal.Type {
@ -425,9 +451,9 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) {
return pushConstantNumberWAT(lit.Literal.Primitive, lit.Literal.Value), nil return pushConstantNumberWAT(lit.Literal.Primitive, lit.Literal.Value), nil
case Literal_Boolean: case Literal_Boolean:
if lit.Literal.Value.(bool) { if lit.Literal.Value.(bool) {
return "i32.const 1\n", nil return ofInstruction("i32.const 1"), nil
} else { } else {
return "i32.const 0\n", nil return ofInstruction("i32.const 0"), nil
} }
case Literal_String: case Literal_String:
panic("not implemented") panic("not implemented")
@ -435,13 +461,17 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) {
case Expression_VariableReference: case Expression_VariableReference:
ref := expr.Value.(VariableReferenceExpression) ref := expr.Value.(VariableReferenceExpression)
cast := "" cast := emptyCode()
if expr.ValueType.Type == Type_Primitive { if expr.ValueType.Type == Type_Primitive {
// TODO: technically only needed for function parameters because functions can be called from outside WASM so they might not be fully type checked // TODO: technically only needed for function parameters because functions can be called from outside WASM so they might not be fully type checked
cast = getTypeCast(expr.ValueType.Value.(PrimitiveType)) cast = getTypeCast(expr.ValueType.Value.(PrimitiveType))
} }
return "local.get $" + strconv.Itoa(getLocal(c.CurrentBlock, ref.Variable).Index) + "\n" + cast, nil code := emptyCode()
code.add("local.get $" + strconv.Itoa(getLocal(c.CurrentBlock, ref.Variable).Index))
code.addAll(cast)
return code, nil
case Expression_Binary: case Expression_Binary:
binary := expr.Value.(BinaryExpression) binary := expr.Value.(BinaryExpression)
@ -451,39 +481,39 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) {
watLeft, err := c.compileExpressionWAT(binary.Left) watLeft, err := c.compileExpressionWAT(binary.Left)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
watRight, err := c.compileExpressionWAT(binary.Right) watRight, err := c.compileExpressionWAT(binary.Right)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
op := c.compileOperationWAT(binary.Operation, operandType) op := c.compileOperationWAT(binary.Operation, operandType)
return watLeft + watRight + op + getTypeCast(exprType), nil return concat(watLeft, watRight, op, getTypeCast(exprType)), nil
case Expression_Tuple: case Expression_Tuple:
tuple := expr.Value.(TupleExpression) tuple := expr.Value.(TupleExpression)
wat := "" wat := emptyCode()
for _, member := range tuple.Members { for _, member := range tuple.Members {
memberWAT, err := c.compileExpressionWAT(member) memberWAT, err := c.compileExpressionWAT(member)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
wat += memberWAT wat.addAll(memberWAT)
} }
return wat, nil return wat, nil
case Expression_FunctionCall: case Expression_FunctionCall:
fc := expr.Value.(FunctionCallExpression) fc := expr.Value.(FunctionCallExpression)
wat := "" wat := emptyCode()
if fc.Parameters != nil { if fc.Parameters != nil {
wat, err = c.compileExpressionWAT(*fc.Parameters) wat, err = c.compileExpressionWAT(*fc.Parameters)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
} }
@ -492,39 +522,40 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) {
if !c.Wasm64 { if !c.Wasm64 {
cast, err := castPrimitiveWAT(Primitive_U64, Primitive_U32) cast, err := castPrimitiveWAT(Primitive_U64, Primitive_U32)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
wat += cast wat.addAll(cast)
} }
wat += "memory.grow\n" wat.add("memory.grow")
if !c.Wasm64 { if !c.Wasm64 {
cast, err := castPrimitiveWAT(Primitive_I32, Primitive_I64) cast, err := castPrimitiveWAT(Primitive_I32, Primitive_I64)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
wat += cast wat.addAll(cast)
} }
return wat, nil return wat, nil
case BUILTIN_MEMORY_SIZE: case BUILTIN_MEMORY_SIZE:
wat += "memory.size\n" wat.add("memory.size")
if !c.Wasm64 { if !c.Wasm64 {
cast, err := castPrimitiveWAT(Primitive_U32, Primitive_U64) cast, err := castPrimitiveWAT(Primitive_U32, Primitive_U64)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
wat += cast wat.addAll(cast)
} }
return wat, nil return wat, nil
default: default:
return wat + "call $" + fc.Function + "\n", nil wat.add("call $" + fc.Function)
return wat, nil
} }
case Expression_Unary: case Expression_Unary:
unary := expr.Value.(UnaryExpression) unary := expr.Value.(UnaryExpression)
@ -532,35 +563,46 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) {
wat, err := c.compileExpressionWAT(unary.Value) wat, err := c.compileExpressionWAT(unary.Value)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
watType := getPrimitiveWATType(exprType) watType := getPrimitiveWATType(exprType)
switch unary.Operation { switch unary.Operation {
case UnaryOperation_Negate: case UnaryOperation_Negate:
if isFloatingPoint(exprType) { if isFloatingPoint(exprType) {
return wat + watType + ".neg\n", nil wat.add(watType + ".neg")
return wat, nil
} else { } else {
return watType + ".const 0\n" + wat + watType + ".sub\n" + getTypeCast(exprType), nil code := emptyCode()
code.add(watType + ".const 0")
code.addAll(wat)
code.add(watType + ".sub")
code.addAll(getTypeCast(exprType))
return code, nil
} }
case UnaryOperation_Nop: case UnaryOperation_Nop:
return wat, nil return wat, nil
case UnaryOperation_BitwiseNot: case UnaryOperation_BitwiseNot:
if getBits(exprType) == 64 { if getBits(exprType) == 64 {
return wat + watType + ".const 0xFFFFFFFFFFFFFFFF\n" + watType + ".xor\n" + getTypeCast(exprType), nil wat.add(watType + ".const 0xFFFFFFFFFFFFFFFF")
} else { } else {
return wat + watType + ".const 0xFFFFFFFF\n" + watType + ".xor\n" + getTypeCast(exprType), nil wat.add(watType + ".const 0xFFFFFFFF")
} }
wat.add(watType + ".xor")
wat.addAll(getTypeCast(exprType))
return wat, nil
case UnaryOperation_LogicalNot: case UnaryOperation_LogicalNot:
return wat + "i32.eqz\n", nil wat.add("i32.eqz")
return wat, nil
case UnaryOperation_PreIncrement, UnaryOperation_PreDecrement, UnaryOperation_PostIncrement, UnaryOperation_PostDecrement: case UnaryOperation_PreIncrement, UnaryOperation_PreDecrement, UnaryOperation_PostIncrement, UnaryOperation_PostDecrement:
valueType := c.getWATType(*unary.Value.ValueType) valueType := c.getWATType(*unary.Value.ValueType)
updateWAT := valueType + ".const 1\n" updateWAT := ofInstruction(valueType + ".const 1")
if unary.Operation == UnaryOperation_PreIncrement || unary.Operation == UnaryOperation_PostIncrement { if unary.Operation == UnaryOperation_PreIncrement || unary.Operation == UnaryOperation_PostIncrement {
updateWAT += valueType + ".add\n" updateWAT.add(valueType + ".add")
} else { } else {
updateWAT += valueType + ".sub\n" updateWAT.add(valueType + ".sub")
} }
return c.compileAssignmentUpdateExpressionWAT(unary.Value, updateWAT, unary.Operation == UnaryOperation_PostIncrement || unary.Operation == UnaryOperation_PostDecrement) return c.compileAssignmentUpdateExpressionWAT(unary.Value, updateWAT, unary.Operation == UnaryOperation_PostIncrement || unary.Operation == UnaryOperation_PostDecrement)
@ -570,7 +612,7 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) {
wat, err := c.compileExpressionWAT(cast.Value) wat, err := c.compileExpressionWAT(cast.Value)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
// TODO: fine, as it is currently only allowed for primitive types // TODO: fine, as it is currently only allowed for primitive types
@ -578,20 +620,20 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) {
toType := cast.Type.Value.(PrimitiveType) toType := cast.Type.Value.(PrimitiveType)
castWAT, err := castPrimitiveWAT(fromType, toType) castWAT, err := castPrimitiveWAT(fromType, toType)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
return wat + castWAT, nil return concat(wat, castWAT), nil
case Expression_RawMemoryReference: case Expression_RawMemoryReference:
raw := expr.Value.(RawMemoryReferenceExpression) raw := expr.Value.(RawMemoryReferenceExpression)
wat, err := c.compileExpressionWAT(raw.Address) wat, err := c.compileExpressionWAT(raw.Address)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
if raw.Type.Type == Type_Primitive { if raw.Type.Type == Type_Primitive {
wat += c.getWATType(raw.Type) + ".load\n" // TODO: use load8/load16(_s/u) for smaller types wat.add(c.getWATType(raw.Type) + ".load") // TODO: use load8/load16(_s/u) for smaller types
} }
return wat, nil return wat, nil
@ -606,60 +648,66 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) {
arrayWAT, err := c.compileExpressionWAT(array.Array) arrayWAT, err := c.compileExpressionWAT(array.Array)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
arrayWAT += "local.set $" + strconv.Itoa(localArray.Index) + "\n" arrayWAT.add("local.set $" + strconv.Itoa(localArray.Index))
indexWAT, err := c.compileExpressionWAT(array.Index) indexWAT, err := c.compileExpressionWAT(array.Index)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
if !c.Wasm64 { if !c.Wasm64 {
cast, err := castPrimitiveWAT(Primitive_I64, Primitive_I32) cast, err := castPrimitiveWAT(Primitive_I64, Primitive_I32)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
indexWAT += cast indexWAT.addAll(cast)
} }
indexWAT += "local.set $" + strconv.Itoa(localIndex.Index) + "\n" indexWAT.add("local.set $" + strconv.Itoa(localIndex.Index))
wat := arrayWAT + indexWAT wat := concat(arrayWAT, indexWAT)
if _, ok := c.CompileOptions[COMPILE_OPTION_NO_BOUNDS_CHECK]; !ok { if _, ok := c.CompileOptions[COMPILE_OPTION_NO_BOUNDS_CHECK]; !ok {
// Error if index < 0 // Error if index < 0
wat += "block\n" wat.add(
wat += "local.get $" + strconv.Itoa(localIndex.Index) + "\n" "block",
wat += c.getAddressWATType() + ".const 0\n" "local.get $"+strconv.Itoa(localIndex.Index),
wat += c.getAddressWATType() + ".ge_s\n" c.getAddressWATType()+".const 0",
wat += "br_if 0\n" c.getAddressWATType()+".ge_s",
wat += "call $__builtin_panic\n" "br_if 0",
wat += "end\n" "call $__builtin_panic",
"end",
)
// Error if index >= array length // Error if index >= array length
wat += "block\n" wat.add(
wat += "local.get $" + strconv.Itoa(localIndex.Index) + "\n" "block",
wat += "local.get $" + strconv.Itoa(localArray.Index) + "\n" "local.get $"+strconv.Itoa(localIndex.Index),
wat += "i32.load\n" // Load array length "local.get $"+strconv.Itoa(localArray.Index),
wat += c.getAddressWATType() + ".lt_s\n" "i32.load", // Load array length
wat += "br_if 0\n" c.getAddressWATType()+".lt_s",
wat += "call $__builtin_panic\n" "br_if 0",
wat += "end\n" "call $__builtin_panic",
"end",
)
} }
elementType := array.Array.ValueType.Value.(ArrayType).ElementType elementType := array.Array.ValueType.Value.(ArrayType).ElementType
wat += "local.get $" + strconv.Itoa(localIndex.Index) + "\n" wat.add(
wat += c.getAddressWATType() + ".const " + strconv.Itoa(c.getTypeSizeBytes(elementType)) + "\n" "local.get $"+strconv.Itoa(localIndex.Index),
wat += c.getAddressWATType() + ".mul\n" c.getAddressWATType()+".const "+strconv.Itoa(c.getTypeSizeBytes(elementType)),
wat += "local.get $" + strconv.Itoa(localArray.Index) + "\n" c.getAddressWATType()+".mul",
wat += c.getAddressWATType() + ".add\n" "local.get $"+strconv.Itoa(localArray.Index),
wat += c.getAddressWATType() + ".const 4\n" // first 4 bytes = length c.getAddressWATType()+".add",
wat += c.getAddressWATType() + ".add\n" c.getAddressWATType()+".const 4", // first 4 bytes = length
c.getAddressWATType()+".add",
)
wat += c.getWATType(elementType) + ".load\n" // TODO: use load8/load16(_s/u) for smaller types wat.add(c.getWATType(elementType) + ".load") // TODO: use load8/load16(_s/u) for smaller types
return wat, nil return wat, nil
} }
@ -667,10 +715,10 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) {
panic("expr not implemented") panic("expr not implemented")
} }
func (c *Compiler) compileOperationWAT(operation Operation, operandType PrimitiveType) string { func (c *Compiler) compileOperationWAT(operation Operation, operandType PrimitiveType) Code {
op := "" var op string
suffix := "" var suffix string
if isUnsignedInt(operandType) { if isUnsignedInt(operandType) {
suffix = "u" suffix = "u"
} else { } else {
@ -679,41 +727,41 @@ func (c *Compiler) compileOperationWAT(operation Operation, operandType Primitiv
switch operation { switch operation {
case Operation_Add: case Operation_Add:
op = getPrimitiveWATType(operandType) + ".add\n" op = getPrimitiveWATType(operandType) + ".add"
case Operation_Sub: case Operation_Sub:
op = getPrimitiveWATType(operandType) + ".sub\n" op = getPrimitiveWATType(operandType) + ".sub"
case Operation_Mul: case Operation_Mul:
op = getPrimitiveWATType(operandType) + ".mul\n" op = getPrimitiveWATType(operandType) + ".mul"
case Operation_Div: case Operation_Div:
op = getPrimitiveWATType(operandType) + ".div_" + suffix + "\n" op = getPrimitiveWATType(operandType) + ".div_" + suffix
case Operation_Mod: case Operation_Mod:
op = getPrimitiveWATType(operandType) + ".rem_" + suffix + "\n" op = getPrimitiveWATType(operandType) + ".rem_" + suffix
case Operation_Greater: case Operation_Greater:
op = getPrimitiveWATType(operandType) + ".gt_" + suffix + "\n" op = getPrimitiveWATType(operandType) + ".gt_" + suffix
case Operation_Less: case Operation_Less:
op = getPrimitiveWATType(operandType) + ".lt_" + suffix + "\n" op = getPrimitiveWATType(operandType) + ".lt_" + suffix
case Operation_GreaterEquals: case Operation_GreaterEquals:
op = getPrimitiveWATType(operandType) + ".ge_" + suffix + "\n" op = getPrimitiveWATType(operandType) + ".ge_" + suffix
case Operation_LessEquals: case Operation_LessEquals:
op = getPrimitiveWATType(operandType) + ".le_" + suffix + "\n" op = getPrimitiveWATType(operandType) + ".le_" + suffix
case Operation_NotEquals: case Operation_NotEquals:
op = getPrimitiveWATType(operandType) + ".ne\n" op = getPrimitiveWATType(operandType) + ".ne"
case Operation_Equals: case Operation_Equals:
op = getPrimitiveWATType(operandType) + ".eq\n" op = getPrimitiveWATType(operandType) + ".eq"
default: default:
panic("operation not implemented") panic("operation not implemented")
} }
return op return ofInstruction(op)
} }
func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, error) { func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (Code, error) {
switch stmt.Type { switch stmt.Type {
case Statement_Expression: case Statement_Expression:
expr := stmt.Value.(ExpressionStatement) expr := stmt.Value.(ExpressionStatement)
wat, err := c.compileExpressionWAT(expr.Expression) wat, err := c.compileExpressionWAT(expr.Expression)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
numItems := 0 numItems := 0
@ -725,85 +773,91 @@ func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, er
} }
} }
return wat + strings.Repeat("drop\n", numItems), nil for _ = range numItems {
wat.add("drop")
}
return wat, nil
case Statement_Block: case Statement_Block:
block := stmt.Value.(BlockStatement) block := stmt.Value.(BlockStatement)
wat, err := c.compileBlockWAT(block.Block) wat, err := c.compileBlockWAT(block.Block)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
return wat, nil return wat, nil
case Statement_Return: case Statement_Return:
ret := stmt.Value.(ReturnStatement) ret := stmt.Value.(ReturnStatement)
wat := "" wat := emptyCode()
if ret.Value != nil { if ret.Value != nil {
valueWAT, err := c.compileExpressionWAT(*ret.Value) valueWAT, err := c.compileExpressionWAT(*ret.Value)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
wat += valueWAT wat.addAll(valueWAT)
} }
return wat + "return\n", nil wat.add("return")
return wat, nil
case Statement_DeclareLocalVariable: case Statement_DeclareLocalVariable:
dlv := stmt.Value.(DeclareLocalVariableStatement) dlv := stmt.Value.(DeclareLocalVariableStatement)
if dlv.Initializer == nil { if dlv.Initializer == nil {
return "", nil return emptyCode(), nil
} }
wat, err := c.compileExpressionWAT(*dlv.Initializer) wat, err := c.compileExpressionWAT(*dlv.Initializer)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
return wat + "local.set $" + strconv.Itoa(block.Locals[dlv.Variable].Index) + "\n", nil wat.add("local.set $" + strconv.Itoa(block.Locals[dlv.Variable].Index))
return wat, nil
case Statement_If: case Statement_If:
ifS := stmt.Value.(IfStatement) ifS := stmt.Value.(IfStatement)
conditionWAT, err := c.compileExpressionWAT(ifS.Condition) conditionWAT, err := c.compileExpressionWAT(ifS.Condition)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
condBlockWAT, err := c.compileBlockWAT(ifS.ConditionalBlock) condBlockWAT, err := c.compileBlockWAT(ifS.ConditionalBlock)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
wat := "" wat := emptyCode()
if ifS.ElseBlock != nil { if ifS.ElseBlock != nil {
wat += "block\n" wat.add("block")
} }
// condition // condition
wat += "block\n" wat.add("block")
wat += conditionWAT wat.addAll(conditionWAT)
wat += "i32.eqz\n" // logical not wat.add("i32.eqz") // logical not
wat += "br_if 0\n" wat.add("br_if 0")
// condition is true // condition is true
wat += condBlockWAT wat.addAll(condBlockWAT)
if ifS.ElseBlock != nil { if ifS.ElseBlock != nil {
wat += "br 1\n" // jump over else block wat.add("br 1") // jump over else block
} }
wat += "end\n" wat.add("end")
if ifS.ElseBlock != nil { if ifS.ElseBlock != nil {
// condition is false // condition is false
elseWAT, err := c.compileBlockWAT(ifS.ElseBlock) elseWAT, err := c.compileBlockWAT(ifS.ElseBlock)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
wat += elseWAT wat.addAll(elseWAT)
wat += "end\n" wat.add("end")
} }
return wat, nil return wat, nil
@ -812,26 +866,27 @@ func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, er
conditionWAT, err := c.compileExpressionWAT(while.Condition) conditionWAT, err := c.compileExpressionWAT(while.Condition)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
bodyWAT, err := c.compileBlockWAT(while.Body) bodyWAT, err := c.compileBlockWAT(while.Body)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
wat := "block\n" wat := emptyCode()
wat += "loop\n" wat.add("block")
wat.add("loop")
wat += conditionWAT wat.addAll(conditionWAT)
wat += "i32.eqz\n" wat.add("i32.eqz")
wat += "br_if 1\n" wat.add("br_if 1")
wat += bodyWAT wat.addAll(bodyWAT)
wat += "br 0\n" wat.add("br 0")
wat += "end\n" wat.add("end")
wat += "end\n" wat.add("end")
return wat, nil return wat, nil
} }
@ -839,17 +894,17 @@ func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, er
panic("stmt not implemented") panic("stmt not implemented")
} }
func (c *Compiler) compileBlockWAT(block *Block) (string, error) { func (c *Compiler) compileBlockWAT(block *Block) (Code, error) {
blockWAT := "" blockWAT := emptyCode()
for _, stmt := range block.Statements { for _, stmt := range block.Statements {
c.CurrentBlock = block c.CurrentBlock = block
wat, err := c.compileStatementWAT(stmt, block) wat, err := c.compileStatementWAT(stmt, block)
if err != nil { if err != nil {
return "", err return emptyCode(), err
} }
blockWAT += wat blockWAT.addAll(wat)
} }
return blockWAT, nil return blockWAT, nil
@ -890,14 +945,15 @@ func (c *Compiler) compileFunctionWAT(function *ParsedFunction) (string, error)
funcWAT += "\t(local $" + strconv.Itoa(local.Index) + " " + c.getWATType(local.Type) + ")\n" funcWAT += "\t(local $" + strconv.Itoa(local.Index) + " " + c.getWATType(local.Type) + ")\n"
} }
return funcWAT + blockWat + ")\n", nil return funcWAT + blockWat.toString() + ")\n", nil
} }
func (c *Compiler) compile() (string, error) { func (c *Compiler) compile() (string, error) {
module := "(module (memory $memory 0) (export \"memory\" (memory $memory))\n" module := "(module\n"
module += "(memory $memory 0) (export \"memory\" (memory $memory))\n"
module += "(func $__builtin_panic\n" module += "(func $__builtin_panic\n"
module += "unreachable\n" module += "\tunreachable\n"
module += ")\n" module += ")\n"
for _, file := range c.Files { for _, file := range c.Files {

66
compiler.go Normal file
View File

@ -0,0 +1,66 @@
package main
import "strings"
type Code struct {
Instructions []string
}
func (code *Code) add(instructions ...string) {
code.Instructions = append(code.Instructions, instructions...)
}
func (code *Code) addAll(other Code) {
code.Instructions = append(code.Instructions, other.Instructions...)
}
func ofInstruction(instructions ...string) Code {
return Code{Instructions: instructions}
}
func emptyCode() Code {
return Code{Instructions: []string{}}
}
func clone(code Code) Code {
newInstrs := make([]string, len(code.Instructions))
copy(newInstrs, code.Instructions)
return Code{Instructions: newInstrs}
}
func concat(code Code, other ...Code) Code {
newCode := clone(code)
for _, o := range other {
newCode.addAll(o)
}
return newCode
}
func (code Code) clone() Code {
return clone(code)
}
func (code *Code) isEmpty() bool {
return len(code.Instructions) == 0
}
func (code *Code) toString() string {
text := ""
indent := 1
for _, instr := range code.Instructions {
if instr == "end" {
indent--
}
text += strings.Repeat("\t", indent) + instr + "\n"
if instr == "block" || instr == "loop" {
indent++
}
}
return text
}

2
go.mod
View File

@ -1,3 +1,3 @@
module git.cringe-studios.com/mr/elysium module git.cringe-studios.com/mr/elysium
go 1.21.7 go 1.23.2