diff --git a/backend_wat.go b/backend_wat.go index 8370115..307518c 100644 --- a/backend_wat.go +++ b/backend_wat.go @@ -303,7 +303,19 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) { switch expr.Type { case Expression_Assignment: ass := expr.Value.(AssignmentExpression) - return c.compileAssignmentExpressionWAT(ass) + + if ass.Operation == Operation_Nop { + return c.compileAssignmentExpressionWAT(ass) + } + + watRight, err := c.compileExpressionWAT(ass.Value) + if err != nil { + return "", err + } + + updateOp := c.compileOperationWAT(ass.Operation, ass.Lhs.ValueType.Value.(PrimitiveType)) + + return c.compileAssignmentUpdateExpressionWAT(ass.Lhs, watRight+updateOp, false) case Expression_Literal: lit := expr.Value.(LiteralExpression) switch lit.Literal.Type { @@ -345,41 +357,7 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) { return "", err } - op := "" - - suffix := "" - if isUnsignedInt(operandType) { - suffix = "u" - } else { - suffix = "s" - } - - switch binary.Operation { - case Operation_Add: - op = getPrimitiveWATType(operandType) + ".add\n" - case Operation_Sub: - op = getPrimitiveWATType(operandType) + ".sub\n" - case Operation_Mul: - op = getPrimitiveWATType(operandType) + ".mul\n" - case Operation_Div: - op = getPrimitiveWATType(operandType) + ".div_" + suffix + "\n" - case Operation_Mod: - op = getPrimitiveWATType(operandType) + ".rem_" + suffix + "\n" - case Operation_Greater: - op = getPrimitiveWATType(operandType) + ".gt_" + suffix + "\n" - case Operation_Less: - op = getPrimitiveWATType(operandType) + ".lt_" + suffix + "\n" - case Operation_GreaterEquals: - op = getPrimitiveWATType(operandType) + ".ge_" + suffix + "\n" - case Operation_LessEquals: - op = getPrimitiveWATType(operandType) + ".le_" + suffix + "\n" - case Operation_NotEquals: - op = getPrimitiveWATType(operandType) + ".ne\n" - case Operation_Equals: - op = getPrimitiveWATType(operandType) + ".eq\n" - default: - panic("operation not implemented") - } + op := c.compileOperationWAT(binary.Operation, operandType) return watLeft + watRight + op + getTypeCast(exprType), nil case Expression_Tuple: @@ -482,6 +460,46 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) { panic("expr not implemented") } +func (c *Compiler) compileOperationWAT(operation Operation, operandType PrimitiveType) string { + op := "" + + suffix := "" + if isUnsignedInt(operandType) { + suffix = "u" + } else { + suffix = "s" + } + + switch operation { + case Operation_Add: + op = getPrimitiveWATType(operandType) + ".add\n" + case Operation_Sub: + op = getPrimitiveWATType(operandType) + ".sub\n" + case Operation_Mul: + op = getPrimitiveWATType(operandType) + ".mul\n" + case Operation_Div: + op = getPrimitiveWATType(operandType) + ".div_" + suffix + "\n" + case Operation_Mod: + op = getPrimitiveWATType(operandType) + ".rem_" + suffix + "\n" + case Operation_Greater: + op = getPrimitiveWATType(operandType) + ".gt_" + suffix + "\n" + case Operation_Less: + op = getPrimitiveWATType(operandType) + ".lt_" + suffix + "\n" + case Operation_GreaterEquals: + op = getPrimitiveWATType(operandType) + ".ge_" + suffix + "\n" + case Operation_LessEquals: + op = getPrimitiveWATType(operandType) + ".le_" + suffix + "\n" + case Operation_NotEquals: + op = getPrimitiveWATType(operandType) + ".ne\n" + case Operation_Equals: + op = getPrimitiveWATType(operandType) + ".eq\n" + default: + panic("operation not implemented") + } + + return op +} + func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, error) { switch stmt.Type { case Statement_Expression: diff --git a/example/test.ely b/example/test.ely index a1653bb..4cb5340 100644 --- a/example/test.ely +++ b/example/test.ely @@ -33,3 +33,8 @@ u64 fib(u64 i) { u64 sub(u64 a) { return --raw(u64, 0xFFu64); } + +u64 assign(u64 a) { + a += 1u64; + return raw(u64, a) += 2u64; +} diff --git a/parser.go b/parser.go index 9c1fef9..cad9b12 100644 --- a/parser.go +++ b/parser.go @@ -101,8 +101,9 @@ type Expression struct { } type AssignmentExpression struct { - Lhs Expression - Value Expression + Lhs Expression + Value Expression + Operation Operation } type LiteralExpression struct { @@ -127,6 +128,7 @@ const ( Operation_LessEquals Operation_NotEquals Operation_Equals + Operation_Nop ) type UnaryOperation uint32 @@ -747,14 +749,10 @@ func (p *Parser) tryAssignmentExpression() (*Expression, error) { return nil, err } - if *op != Operator_Equals { - // TODO: incorrect, evaluates lhs twice - operation := getOperation(*op) - expr = &Expression{Type: Expression_Binary, Value: BinaryExpression{Left: *lhs, Right: *expr, Operation: operation}, Position: lhs.Position} - } + operation := getOperation(*op) *p = pCopy - return &Expression{Type: Expression_Assignment, Value: AssignmentExpression{Lhs: *lhs, Value: *expr}, Position: lhs.Position}, nil + return &Expression{Type: Expression_Assignment, Value: AssignmentExpression{Lhs: *lhs, Value: *expr, Operation: operation}, Position: lhs.Position}, nil } func (p *Parser) tryExpression() (*Expression, error) { diff --git a/types.go b/types.go index 164d681..a38bf12 100644 --- a/types.go +++ b/types.go @@ -172,6 +172,8 @@ func getOperation(operator Operator) Operation { return Operation_LessEquals case Operator_NotEquals: return Operation_NotEquals + case Operator_Equals: + return Operation_Nop default: return InvalidValue } diff --git a/validator.go b/validator.go index 4e386cc..1c6ac4d 100644 --- a/validator.go +++ b/validator.go @@ -186,6 +186,11 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error return errors } + if assignment.Operation != Operation_Equals && (assignment.Lhs.ValueType.Type != Type_Primitive || assignment.Value.ValueType.Type != Type_Primitive) { + errors = append(errors, v.createError("both sides of an arithmetic expression must evaluate to a primitive type", expr.Position)) + return errors + } + if !isSameType(*assignment.Value.ValueType, *assignment.Lhs.ValueType) { if !isTypeExpandableTo(*assignment.Value.ValueType, *assignment.Lhs.ValueType) { errors = append(errors, v.createError(fmt.Sprintf("cannot assign expression of type %s to variable of type %s", *assignment.Value.ValueType, *assignment.Lhs.ValueType), expr.Position))