diff --git a/backend_wat.go b/backend_wat.go index edb3d64..19c2e63 100644 --- a/backend_wat.go +++ b/backend_wat.go @@ -218,15 +218,15 @@ func compileExpressionWAT(expr Expression, block Block) (string, error) { } switch arith.Operation { - case Arithmetic_Add: + case Operation_Add: op = getPrimitiveWATType(exprType) + ".add\n" - case Arithmetic_Sub: + case Operation_Sub: op = getPrimitiveWATType(exprType) + ".sub\n" - case Arithmetic_Mul: + case Operation_Mul: op = getPrimitiveWATType(exprType) + ".mul\n" - case Arithmetic_Div: + case Operation_Div: op = getPrimitiveWATType(exprType) + ".div_" + suffix + "\n" - case Arithmetic_Mod: + case Operation_Mod: op = getPrimitiveWATType(exprType) + ".rem_" + suffix + "\n" default: panic("operation not implemented") diff --git a/parser.go b/parser.go index ed57dde..82a4b2c 100644 --- a/parser.go +++ b/parser.go @@ -104,27 +104,28 @@ type VariableReferenceExpression struct { Variable string } -type ArithmeticOperation uint32 +type Operation uint32 const ( - Arithmetic_Add ArithmeticOperation = iota - Arithmetic_Sub - Arithmetic_Mul - Arithmetic_Div - Arithmetic_Mod - Arithmetic_Greater - Arithmetic_Less - Arithmetic_GreaterEquals - Arithmetic_LessEquals - Arithmetic_LogicalNot - Arithmetic_NotEquals - Arithmetic_Equals + Operation_Add Operation = iota + Operation_Sub + Operation_Mul + Operation_Div + Operation_Mod + Operation_Greater + Operation_Less + Operation_GreaterEquals + Operation_LessEquals + Operation_LogicalNot + Operation_NotEquals + Operation_Equals ) type BinaryExpression struct { - Operation ArithmeticOperation - Left Expression - Right Expression + Operation Operation + Left Expression + Right Expression + ResultType *Type } type TupleExpression struct { @@ -516,7 +517,7 @@ func (p *Parser) tryBinaryExpression0(opFunc func() (*Expression, error), operat return nil, p.error("expected expression") } - left = &Expression{Type: Expression_Binary, Value: BinaryExpression{Operation: getArithmeticOperation(*op), Left: *left, Right: *right}, Position: left.Position} + left = &Expression{Type: Expression_Binary, Value: BinaryExpression{Operation: getOperation(*op), Left: *left, Right: *right}, Position: left.Position} } } diff --git a/types.go b/types.go index 063f1fd..05c8e94 100644 --- a/types.go +++ b/types.go @@ -111,33 +111,46 @@ func isAssigmentOperator(operator Operator) bool { } } -func getArithmeticOperation(operator Operator) ArithmeticOperation { +func getOperation(operator Operator) Operation { switch operator { case Operator_Greater: - return Arithmetic_Greater + return Operation_Greater case Operator_Less: - return Arithmetic_Less + return Operation_Less case Operator_Not: - return Arithmetic_LogicalNot + return Operation_LogicalNot case Operator_Plus, Operator_PlusEquals: - return Arithmetic_Add + return Operation_Add case Operator_Minus, Operator_MinusEquals: - return Arithmetic_Sub + return Operation_Sub case Operator_Multiply, Operator_MultiplyEquals: - return Arithmetic_Mul + return Operation_Mul case Operator_Divide, Operator_DivideEquals: - return Arithmetic_Div + return Operation_Div case Operator_Modulo, Operator_ModuloEquals: - return Arithmetic_Mod + return Operation_Mod case Operator_EqualsEquals: - return Arithmetic_Equals + return Operation_Equals case Operator_GreaterEquals: - return Arithmetic_GreaterEquals + return Operation_GreaterEquals case Operator_LessEquals: - return Arithmetic_LessEquals + return Operation_LessEquals case Operator_NotEquals: - return Arithmetic_NotEquals + return Operation_NotEquals default: return InvalidValue } } + +func isBooleanOperation(operation Operation) bool { + switch operation { + case Operation_Greater, Operation_Less, Operation_LogicalNot, Operation_Equals, Operation_GreaterEquals, Operation_LessEquals, Operation_NotEquals: + return true + default: + return false + } +} + +func isArithmeticOperation(operation Operation) bool { + return !isBooleanOperation(operation) +} diff --git a/validator.go b/validator.go index 582ae5b..e8b2a9f 100644 --- a/validator.go +++ b/validator.go @@ -58,7 +58,7 @@ func (v *Validator) validateImport(imp *Import) []error { return nil } -func (v *Validator) getArithmeticResultType(expr *Expression, left PrimitiveType, right PrimitiveType, operation ArithmeticOperation) (PrimitiveType, error) { +func (v *Validator) getArithmeticResultType(expr *Expression, left PrimitiveType, right PrimitiveType, operation Operation) (PrimitiveType, error) { if left == Primitive_Bool || right == Primitive_Bool { return InvalidValue, v.createError("bool type cannot be used in arithmetic expressions", expr.Position) } @@ -118,30 +118,61 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B expr.ValueType = &local.Type case Expression_Binary: - arithmethic := expr.Value.(BinaryExpression) + binary := expr.Value.(BinaryExpression) - errors = append(errors, v.validateExpression(&arithmethic.Left, block)...) - errors = append(errors, v.validateExpression(&arithmethic.Right, block)...) + errors = append(errors, v.validateExpression(&binary.Left, block)...) + errors = append(errors, v.validateExpression(&binary.Right, block)...) if len(errors) != 0 { return errors } - if arithmethic.Left.ValueType.Type != Type_Primitive || arithmethic.Right.ValueType.Type != Type_Primitive { - errors = append(errors, v.createError("both sides of an arithmetic expression must be a primitive type", expr.Position)) - return errors + if isBooleanOperation(binary.Operation) { + if binary.Left.ValueType.Type != Type_Primitive || binary.Right.ValueType.Type != Type_Primitive { + errors = append(errors, v.createError("cannot compare non-primitive types", expr.Position)) + return errors + } + + leftType := binary.Left.ValueType.Value.(PrimitiveType) + rightType := binary.Right.ValueType.Value.(PrimitiveType) + + var result PrimitiveType = InvalidValue + if isPrimitiveTypeExpandableTo(leftType, rightType) { + result = leftType + } + + if isPrimitiveTypeExpandableTo(rightType, leftType) { + result = leftType + } + + if result == InvalidValue { + errors = append(errors, v.createError("cannot compare these types without explicit cast", expr.Position)) + return errors + } + + binary.ResultType = &Type{Type: Type_Primitive, Value: result} + expr.ValueType = &Type{Type: Type_Primitive, Value: Primitive_Bool} } - leftType := arithmethic.Left.ValueType.Value.(PrimitiveType) - rightType := arithmethic.Right.ValueType.Value.(PrimitiveType) - result, err := v.getArithmeticResultType(expr, leftType, rightType, arithmethic.Operation) - if err != nil { - errors = append(errors, err) - return errors + if isArithmeticOperation(binary.Operation) { + if binary.Left.ValueType.Type != Type_Primitive || binary.Right.ValueType.Type != Type_Primitive { + errors = append(errors, v.createError("both sides of an arithmetic expression must be a primitive type", expr.Position)) + return errors + } + + leftType := binary.Left.ValueType.Value.(PrimitiveType) + rightType := binary.Right.ValueType.Value.(PrimitiveType) + result, err := v.getArithmeticResultType(expr, leftType, rightType, binary.Operation) + if err != nil { + errors = append(errors, err) + return errors + } + + binary.ResultType = &Type{Type: Type_Primitive, Value: result} + expr.ValueType = &Type{Type: Type_Primitive, Value: result} } - expr.ValueType = &Type{Type: Type_Primitive, Value: result} - expr.Value = arithmethic + expr.Value = binary case Expression_Tuple: tuple := expr.Value.(TupleExpression) @@ -223,6 +254,8 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B expr.ValueType = neg.Value.ValueType expr.Value = neg + default: + panic("expr not implemented") } return errors @@ -275,6 +308,28 @@ func (v *Validator) validateStatement(stmt *Statement, block *Block, functionLoc *functionLocals = append(*functionLocals, local) // TODO: check if assignment of initializer is correct + stmt.Value = dlv + case Statement_If: + ifS := stmt.Value.(IfStatement) + + errors = append(errors, v.validateExpression(&ifS.Condition, block)...) + errors = append(errors, v.validateBlock(&ifS.ConditionalBlock, functionLocals)...) + + if ifS.ElseBlock != nil { + errors = append(errors, v.validateBlock(ifS.ElseBlock, functionLocals)...) + } + + if len(errors) != 0 { + return errors + } + + if ifS.Condition.ValueType.Type != Type_Primitive || ifS.Condition.ValueType.Value.(PrimitiveType) != Primitive_Bool { + errors = append(errors, v.createError("condition must evaluate to boolean", ifS.Condition.Position)) + } + + stmt.Value = ifS + default: + panic("stmt not implemented") } return errors