diff --git a/backend_wat.go b/backend_wat.go index 6baecd8..4b77920 100644 --- a/backend_wat.go +++ b/backend_wat.go @@ -2,7 +2,6 @@ package main import ( "errors" - "log" "strconv" ) @@ -60,7 +59,7 @@ func pushConstantNumberWAT(primitive PrimitiveType, value any) string { case Primitive_I64: return "i64.const " + strconv.FormatInt(value.(int64), 10) + "\n" case Primitive_U64: - return "u64.const " + strconv.FormatUint(value.(uint64), 10) + "\n" + return "i64.const " + strconv.FormatUint(value.(uint64), 10) + "\n" case Primitive_F32: return "f32.const " + strconv.FormatFloat(value.(float64), 'f', -1, 32) + "\n" case Primitive_F64: @@ -70,9 +69,7 @@ func pushConstantNumberWAT(primitive PrimitiveType, value any) string { panic("invalid type") } -func upcastTypeWAT(from PrimitiveType, to PrimitiveType) (string, error) { - // TODO: refactor - +func castPrimitiveWAT(from PrimitiveType, to PrimitiveType) (string, error) { if from == to { return "", nil } @@ -158,7 +155,13 @@ func compileExpressionWAT(expr Expression, block Block) (string, error) { } case Expression_VariableReference: ref := expr.Value.(VariableReferenceExpression) - return "local.get $" + strconv.Itoa(block.Locals[ref.Variable].Index) + "\n", nil + + cast := "" + if expr.ValueType.Type == Type_Primitive { + cast = getTypeCast(expr.ValueType.Value.(PrimitiveType)) + } + + return "local.get $" + strconv.Itoa(block.Locals[ref.Variable].Index) + "\n" + cast, nil case Expression_Arithmetic: arith := expr.Value.(ArithmeticExpression) @@ -170,9 +173,7 @@ func compileExpressionWAT(expr Expression, block Block) (string, error) { return "", err } - log.Printf("%+#v", arith) - - castLeft, err := upcastTypeWAT(arith.Left.ValueType.Value.(PrimitiveType), exprType) + castLeft, err := castPrimitiveWAT(arith.Left.ValueType.Value.(PrimitiveType), exprType) if err != nil { return "", err } @@ -182,7 +183,7 @@ func compileExpressionWAT(expr Expression, block Block) (string, error) { return "", err } - castRight, err := upcastTypeWAT(arith.Left.ValueType.Value.(PrimitiveType), exprType) + castRight, err := castPrimitiveWAT(arith.Right.ValueType.Value.(PrimitiveType), exprType) if err != nil { return "", err } @@ -204,12 +205,12 @@ func compileExpressionWAT(expr Expression, block Block) (string, error) { case Arithmetic_Mul: op = getPrimitiveWATType(exprType) + ".mul\n" case Arithmetic_Div: - op = getPrimitiveWATType(exprType) + ".div:" + suffix + "\n" + op = getPrimitiveWATType(exprType) + ".div_" + suffix + "\n" case Arithmetic_Mod: - op = getPrimitiveWATType(exprType) + ".rem" + suffix + "\n" + op = getPrimitiveWATType(exprType) + ".rem_" + suffix + "\n" } - return watLeft + castLeft + watRight + castRight + op + getTypeCast(expr.ValueType.Value.(PrimitiveType)), nil + return watLeft + castLeft + watRight + castRight + op + getTypeCast(exprType), nil case Expression_Tuple: } diff --git a/example/add.lang b/example/add.lang index 454736f..439547f 100644 --- a/example/add.lang +++ b/example/add.lang @@ -1,3 +1,3 @@ -u8 add(u8 a, u8 b) { - return a + b; +u64 add(u8 a, u64 b) { + return a * a + b * b; } diff --git a/parser.go b/parser.go index f449b8e..e5b8b79 100644 --- a/parser.go +++ b/parser.go @@ -452,16 +452,15 @@ func (p *Parser) tryMultiplicativeExpression() (*Expression, error) { operation = Arithmetic_Mul case Operator_Divide: operation = Arithmetic_Div + case Operator_Plus: + operation = Arithmetic_Add + case Operator_Minus: + operation = Arithmetic_Sub case Operator_Modulo: fallthrough default: operation = Arithmetic_Mod } - if *op == Operator_Plus { - operation = Arithmetic_Add - } else { - operation = Arithmetic_Sub - } return &Expression{Type: Expression_Arithmetic, Value: ArithmeticExpression{Operation: operation, Left: *left, Right: *right}}, nil } diff --git a/validator.go b/validator.go index d45a8b6..0242ca4 100644 --- a/validator.go +++ b/validator.go @@ -20,22 +20,19 @@ func isTypeExpandableTo(from PrimitiveType, to PrimitiveType) bool { } switch from { - case Primitive_I8: - case Primitive_U8: + case Primitive_I8, Primitive_U8: if to == Primitive_I16 || to == Primitive_U16 { return true } fallthrough - case Primitive_I16: - case Primitive_U16: + case Primitive_I16, Primitive_U16: if to == Primitive_I32 || to == Primitive_U32 { return true } fallthrough - case Primitive_I32: - case Primitive_U32: + case Primitive_I32, Primitive_U32: if to == Primitive_I64 || to == Primitive_U64 { return true } @@ -127,7 +124,7 @@ func validateExpression(expr *Expression, block *Block) []error { } leftType := arithmethic.Left.ValueType.Value.(PrimitiveType) - rightType := arithmethic.Left.ValueType.Value.(PrimitiveType) + rightType := arithmethic.Right.ValueType.Value.(PrimitiveType) result, err := getArithmeticResultType(leftType, rightType, arithmethic.Operation) if err != nil { errors = append(errors, err)