From 0236dfbb2de711596f554094f746fedb2d2c49a6 Mon Sep 17 00:00:00 2001 From: MrLetsplay Date: Tue, 9 Apr 2024 13:57:17 +0200 Subject: [PATCH] More unary expressions --- backend_wat.go | 28 ++++++++++++++------ example/test.ely | 14 ++++++++++ lexer.go | 5 +++- parser.go | 68 +++++++++++++++++++++++++++--------------------- stdlib/alloc.ely | 3 --- types.go | 19 +++++++++++--- validator.go | 30 ++++++++++++++------- 7 files changed, 114 insertions(+), 53 deletions(-) diff --git a/backend_wat.go b/backend_wat.go index 8aab8f8..0d6eaf6 100644 --- a/backend_wat.go +++ b/backend_wat.go @@ -323,23 +323,35 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) { } return wat + "call $" + fc.Function + "\n", nil - case Expression_Negate: - neg := expr.Value.(NegateExpression) + case Expression_Unary: + unary := expr.Value.(UnaryExpression) exprType := expr.ValueType.Value.(PrimitiveType) - wat, err := c.compileExpressionWAT(neg.Value) + wat, err := c.compileExpressionWAT(unary.Value) if err != nil { return "", err } watType := getPrimitiveWATType(exprType) - if isSignedInt(exprType) || isUnsignedInt(exprType) { - return watType + ".const 0\n" + wat + watType + ".sub\n", nil + switch unary.Operation { + case UnaryOperation_Negate: + if isFloatingPoint(exprType) { + return wat + watType + ".neg\n", nil + } else { + return watType + ".const 0\n" + wat + watType + ".sub\n" + getTypeCast(exprType), nil + } + case UnaryOperation_Nop: + return wat, nil + case UnaryOperation_BitwiseNot: + if getBits(exprType) == 64 { + return wat + watType + ".const 0xFFFFFFFFFFFFFFFF\n" + watType + ".xor\n" + getTypeCast(exprType), nil + } else { + return wat + watType + ".const 0xFFFFFFFF\n" + watType + ".xor\n" + getTypeCast(exprType), nil + } + case UnaryOperation_LogicalNot: + return wat + "i32.eqz\n", nil } - if isFloatingPoint(exprType) { - return watType + ".neg\n", nil - } case Expression_Cast: cast := expr.Value.(CastExpression) diff --git a/example/test.ely b/example/test.ely index f748a07..9ff23c4 100644 --- a/example/test.ely +++ b/example/test.ely @@ -6,3 +6,17 @@ u32 c(u32 n) { } return sum; } + + +u64 fib(u64 i) { + u64 fibA = 0u64; + u64 fibB = 1u64; + + while((i -= 1u64) > 0u64) { + u64 tmp = fibB; + fibB = fibA + fibB; + fibA = tmp; + } + + return fibA; +} diff --git a/lexer.go b/lexer.go index 5a5e8ca..034caf5 100644 --- a/lexer.go +++ b/lexer.go @@ -55,7 +55,7 @@ const ( type Operator uint32 -var Operators []string = []string{"=", ">", "<", "!", "+", "-", "*", "/", "%", "==", ">=", "<=", "!=", "+=", "-=", "*=", "/=", "%="} +var Operators []string = []string{"=", ">", "<", "!", "+", "-", "*", "/", "%", "==", ">=", "<=", "!=", "+=", "-=", "*=", "/=", "%=", "++", "--", "~"} const ( Operator_Equals Operator = iota @@ -76,6 +76,9 @@ const ( Operator_MultiplyEquals Operator_DivideEquals Operator_ModuloEquals + Operator_PlusPlus + Operator_MinusMinus + Operator_BitwiseNot ) type LiteralType uint32 diff --git a/parser.go b/parser.go index ac46991..5b13e31 100644 --- a/parser.go +++ b/parser.go @@ -87,7 +87,7 @@ const ( Expression_Binary Expression_Tuple Expression_FunctionCall - Expression_Negate + Expression_Unary Expression_ArrayAccess Expression_RawMemoryReference Expression_Cast @@ -125,11 +125,19 @@ const ( Operation_Less Operation_GreaterEquals Operation_LessEquals - Operation_LogicalNot Operation_NotEquals Operation_Equals ) +type UnaryOperation uint32 + +const ( + UnaryOperation_Negate UnaryOperation = iota + UnaryOperation_Nop + UnaryOperation_BitwiseNot + UnaryOperation_LogicalNot +) + type BinaryExpression struct { Operation Operation Left Expression @@ -145,8 +153,9 @@ type FunctionCallExpression struct { Parameters *Expression } -type NegateExpression struct { - Value Expression +type UnaryExpression struct { + Operation UnaryOperation + Value Expression } type ArrayAccessExpression struct { @@ -594,27 +603,27 @@ func (p *Parser) tryUnaryExpression() (*Expression, error) { return nil, nil } - if token.Type == Type_Operator { - op := token.Value.(Operator) - if op == Operator_Minus || op == Operator_Plus { - pCopy.nextToken() - expr, err := pCopy.tryPrimaryExpression() - if err != nil { - return nil, err - } - - if expr == nil { - return nil, nil - } - - if op == Operator_Minus { - expr = &Expression{Type: Expression_Negate, Value: NegateExpression{Value: *expr}, Position: token.Position} - } - - return expr, nil - } + op, err := pCopy.tryOperator(Operator_Minus, Operator_Plus, Operator_BitwiseNot, Operator_Not) + if err != nil { + return nil, err } + if op != nil { + expr, err := pCopy.tryPrimaryExpression() + if err != nil { + return nil, err + } + + if expr == nil { + return nil, nil + } + + *p = pCopy + return &Expression{Type: Expression_Unary, Value: UnaryExpression{Operation: getUnaryOperation(*op), Value: *expr}, Position: token.Position}, nil + } + + // TODO: pre-/postfix in-/decrement expr + return p.tryPrimaryExpression() } @@ -647,7 +656,12 @@ func (p *Parser) tryBinaryExpression0(opFunc func() (*Expression, error), operat return nil, p.error("expected expression") } - left = &Expression{Type: Expression_Binary, Value: BinaryExpression{Operation: getOperation(*op), Left: *left, Right: *right}, Position: left.Position} + operation := getOperation(*op) + if operation == InvalidValue { + return nil, p.error("operator not allowed in binary expression") + } + + left = &Expression{Type: Expression_Binary, Value: BinaryExpression{Operation: operation, Left: *left, Right: *right}, Position: left.Position} } } @@ -679,11 +693,7 @@ func (p *Parser) tryAssignmentExpression() (*Expression, error) { return nil, err } - if lhs == nil { - return nil, nil - } - - if lhs.Type != Expression_VariableReference && lhs.Type != Expression_ArrayAccess && lhs.Type != Expression_RawMemoryReference { + if lhs == nil || (lhs.Type != Expression_VariableReference && lhs.Type != Expression_ArrayAccess && lhs.Type != Expression_RawMemoryReference) { return p.tryBinaryExpression() } diff --git a/stdlib/alloc.ely b/stdlib/alloc.ely index 6f39644..57b62ae 100644 --- a/stdlib/alloc.ely +++ b/stdlib/alloc.ely @@ -1,9 +1,6 @@ module alloc; u64 alloc(u64 size) { - u64 ptr = 0x0u64; - raw(i32, ptr) = 0x03u32; - i32 sus = raw(i32, ptr); return 0u64; } diff --git a/types.go b/types.go index 484b9fa..164d681 100644 --- a/types.go +++ b/types.go @@ -154,8 +154,6 @@ func getOperation(operator Operator) Operation { return Operation_Greater case Operator_Less: return Operation_Less - case Operator_Not: - return Operation_LogicalNot case Operator_Plus, Operator_PlusEquals: return Operation_Add case Operator_Minus, Operator_MinusEquals: @@ -179,9 +177,24 @@ func getOperation(operator Operator) Operation { } } +func getUnaryOperation(operator Operator) UnaryOperation { + switch operator { + case Operator_Minus: + return UnaryOperation_Negate + case Operator_Plus: + return UnaryOperation_Nop + case Operator_BitwiseNot: + return UnaryOperation_BitwiseNot + case Operator_Not: + return UnaryOperation_LogicalNot + default: + return InvalidValue + } +} + func isBooleanOperation(operation Operation) bool { switch operation { - case Operation_Greater, Operation_Less, Operation_LogicalNot, Operation_Equals, Operation_GreaterEquals, Operation_LessEquals, Operation_NotEquals: + case Operation_Greater, Operation_Less, Operation_Equals, Operation_GreaterEquals, Operation_LessEquals, Operation_NotEquals: return true default: return false diff --git a/validator.go b/validator.go index 0ec95f3..4e386cc 100644 --- a/validator.go +++ b/validator.go @@ -156,7 +156,7 @@ func (v *Validator) getArithmeticResultType(expr *Expression, left PrimitiveType return left, nil } - return InvalidValue, v.createError(fmt.Sprintf("cannot use the types [%s, %s] in an arithmetic expression without an explicit cast", left, right), expr.Position) // TODO: include type names in error + return InvalidValue, v.createError(fmt.Sprintf("cannot use the types [%s, %s] in an arithmetic expression without an explicit cast", left, right), expr.Position) } func getLocal(block *Block, variable string) *Local { @@ -324,24 +324,36 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error } } - // TODO: get function and validate using return type expr.ValueType = calledFunc.ReturnType expr.Value = fc - case Expression_Negate: - neg := expr.Value.(NegateExpression) + case Expression_Unary: + unary := expr.Value.(UnaryExpression) - valErrors := v.validateExpression(&neg.Value) + valErrors := v.validateExpression(&unary.Value) if len(valErrors) != 0 { errors = append(errors, valErrors...) return errors } - if neg.Value.ValueType.Type != Type_Primitive { - errors = append(errors, v.createError("cannot negate non-number types", expr.Position)) + if unary.Value.ValueType.Type != Type_Primitive { + errors = append(errors, v.createError("cannot operate on non-primitive types", expr.Position)) + } else { + primitive := unary.Value.ValueType.Value.(PrimitiveType) + if (unary.Operation == UnaryOperation_Negate || unary.Operation == UnaryOperation_Nop /* + sign */) && !isSignedInt(primitive) && !isFloatingPoint(primitive) { + errors = append(errors, v.createError("can only perform negation/unary plus on signed types", expr.Position)) + } + + if unary.Operation == UnaryOperation_LogicalNot && primitive != Primitive_Bool { + errors = append(errors, v.createError("cannot perform logical not on non-bool type", expr.Position)) + } + + if unary.Operation == UnaryOperation_BitwiseNot && !isUnsignedInt(primitive) && !isSignedInt(primitive) { + errors = append(errors, v.createError("cannot perform bitwise not on non-integer type", expr.Position)) + } } - expr.ValueType = neg.Value.ValueType - expr.Value = neg + expr.ValueType = unary.Value.ValueType + expr.Value = unary case Expression_RawMemoryReference: raw := expr.Value.(RawMemoryReferenceExpression)