Fix assignment expressions

This commit is contained in:
MrLetsplay 2024-04-20 13:56:41 +02:00
parent a960b6d3e7
commit b9c1ad12c5
Signed by: mr
SSH Key Fingerprint: SHA256:92jBH80vpXyaZHjaIl47pjRq+Yt7XGTArqQg1V7hSqg
5 changed files with 72 additions and 44 deletions

View File

@ -303,7 +303,19 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) {
switch expr.Type { switch expr.Type {
case Expression_Assignment: case Expression_Assignment:
ass := expr.Value.(AssignmentExpression) ass := expr.Value.(AssignmentExpression)
if ass.Operation == Operation_Nop {
return c.compileAssignmentExpressionWAT(ass) return c.compileAssignmentExpressionWAT(ass)
}
watRight, err := c.compileExpressionWAT(ass.Value)
if err != nil {
return "", err
}
updateOp := c.compileOperationWAT(ass.Operation, ass.Lhs.ValueType.Value.(PrimitiveType))
return c.compileAssignmentUpdateExpressionWAT(ass.Lhs, watRight+updateOp, false)
case Expression_Literal: case Expression_Literal:
lit := expr.Value.(LiteralExpression) lit := expr.Value.(LiteralExpression)
switch lit.Literal.Type { switch lit.Literal.Type {
@ -345,41 +357,7 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) {
return "", err return "", err
} }
op := "" op := c.compileOperationWAT(binary.Operation, operandType)
suffix := ""
if isUnsignedInt(operandType) {
suffix = "u"
} else {
suffix = "s"
}
switch binary.Operation {
case Operation_Add:
op = getPrimitiveWATType(operandType) + ".add\n"
case Operation_Sub:
op = getPrimitiveWATType(operandType) + ".sub\n"
case Operation_Mul:
op = getPrimitiveWATType(operandType) + ".mul\n"
case Operation_Div:
op = getPrimitiveWATType(operandType) + ".div_" + suffix + "\n"
case Operation_Mod:
op = getPrimitiveWATType(operandType) + ".rem_" + suffix + "\n"
case Operation_Greater:
op = getPrimitiveWATType(operandType) + ".gt_" + suffix + "\n"
case Operation_Less:
op = getPrimitiveWATType(operandType) + ".lt_" + suffix + "\n"
case Operation_GreaterEquals:
op = getPrimitiveWATType(operandType) + ".ge_" + suffix + "\n"
case Operation_LessEquals:
op = getPrimitiveWATType(operandType) + ".le_" + suffix + "\n"
case Operation_NotEquals:
op = getPrimitiveWATType(operandType) + ".ne\n"
case Operation_Equals:
op = getPrimitiveWATType(operandType) + ".eq\n"
default:
panic("operation not implemented")
}
return watLeft + watRight + op + getTypeCast(exprType), nil return watLeft + watRight + op + getTypeCast(exprType), nil
case Expression_Tuple: case Expression_Tuple:
@ -482,6 +460,46 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) {
panic("expr not implemented") panic("expr not implemented")
} }
func (c *Compiler) compileOperationWAT(operation Operation, operandType PrimitiveType) string {
op := ""
suffix := ""
if isUnsignedInt(operandType) {
suffix = "u"
} else {
suffix = "s"
}
switch operation {
case Operation_Add:
op = getPrimitiveWATType(operandType) + ".add\n"
case Operation_Sub:
op = getPrimitiveWATType(operandType) + ".sub\n"
case Operation_Mul:
op = getPrimitiveWATType(operandType) + ".mul\n"
case Operation_Div:
op = getPrimitiveWATType(operandType) + ".div_" + suffix + "\n"
case Operation_Mod:
op = getPrimitiveWATType(operandType) + ".rem_" + suffix + "\n"
case Operation_Greater:
op = getPrimitiveWATType(operandType) + ".gt_" + suffix + "\n"
case Operation_Less:
op = getPrimitiveWATType(operandType) + ".lt_" + suffix + "\n"
case Operation_GreaterEquals:
op = getPrimitiveWATType(operandType) + ".ge_" + suffix + "\n"
case Operation_LessEquals:
op = getPrimitiveWATType(operandType) + ".le_" + suffix + "\n"
case Operation_NotEquals:
op = getPrimitiveWATType(operandType) + ".ne\n"
case Operation_Equals:
op = getPrimitiveWATType(operandType) + ".eq\n"
default:
panic("operation not implemented")
}
return op
}
func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, error) { func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, error) {
switch stmt.Type { switch stmt.Type {
case Statement_Expression: case Statement_Expression:

View File

@ -33,3 +33,8 @@ u64 fib(u64 i) {
u64 sub(u64 a) { u64 sub(u64 a) {
return --raw(u64, 0xFFu64); return --raw(u64, 0xFFu64);
} }
u64 assign(u64 a) {
a += 1u64;
return raw(u64, a) += 2u64;
}

View File

@ -103,6 +103,7 @@ type Expression struct {
type AssignmentExpression struct { type AssignmentExpression struct {
Lhs Expression Lhs Expression
Value Expression Value Expression
Operation Operation
} }
type LiteralExpression struct { type LiteralExpression struct {
@ -127,6 +128,7 @@ const (
Operation_LessEquals Operation_LessEquals
Operation_NotEquals Operation_NotEquals
Operation_Equals Operation_Equals
Operation_Nop
) )
type UnaryOperation uint32 type UnaryOperation uint32
@ -747,14 +749,10 @@ func (p *Parser) tryAssignmentExpression() (*Expression, error) {
return nil, err return nil, err
} }
if *op != Operator_Equals {
// TODO: incorrect, evaluates lhs twice
operation := getOperation(*op) operation := getOperation(*op)
expr = &Expression{Type: Expression_Binary, Value: BinaryExpression{Left: *lhs, Right: *expr, Operation: operation}, Position: lhs.Position}
}
*p = pCopy *p = pCopy
return &Expression{Type: Expression_Assignment, Value: AssignmentExpression{Lhs: *lhs, Value: *expr}, Position: lhs.Position}, nil return &Expression{Type: Expression_Assignment, Value: AssignmentExpression{Lhs: *lhs, Value: *expr, Operation: operation}, Position: lhs.Position}, nil
} }
func (p *Parser) tryExpression() (*Expression, error) { func (p *Parser) tryExpression() (*Expression, error) {

View File

@ -172,6 +172,8 @@ func getOperation(operator Operator) Operation {
return Operation_LessEquals return Operation_LessEquals
case Operator_NotEquals: case Operator_NotEquals:
return Operation_NotEquals return Operation_NotEquals
case Operator_Equals:
return Operation_Nop
default: default:
return InvalidValue return InvalidValue
} }

View File

@ -186,6 +186,11 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error
return errors return errors
} }
if assignment.Operation != Operation_Equals && (assignment.Lhs.ValueType.Type != Type_Primitive || assignment.Value.ValueType.Type != Type_Primitive) {
errors = append(errors, v.createError("both sides of an arithmetic expression must evaluate to a primitive type", expr.Position))
return errors
}
if !isSameType(*assignment.Value.ValueType, *assignment.Lhs.ValueType) { if !isSameType(*assignment.Value.ValueType, *assignment.Lhs.ValueType) {
if !isTypeExpandableTo(*assignment.Value.ValueType, *assignment.Lhs.ValueType) { if !isTypeExpandableTo(*assignment.Value.ValueType, *assignment.Lhs.ValueType) {
errors = append(errors, v.createError(fmt.Sprintf("cannot assign expression of type %s to variable of type %s", *assignment.Value.ValueType, *assignment.Lhs.ValueType), expr.Position)) errors = append(errors, v.createError(fmt.Sprintf("cannot assign expression of type %s to variable of type %s", *assignment.Value.ValueType, *assignment.Lhs.ValueType), expr.Position))