diff --git a/backend_wat.go b/backend_wat.go index f98421b..edb3d64 100644 --- a/backend_wat.go +++ b/backend_wat.go @@ -180,8 +180,8 @@ func compileExpressionWAT(expr Expression, block Block) (string, error) { } return "local.get $" + strconv.Itoa(block.Locals[ref.Variable].Index) + "\n" + cast, nil - case Expression_Arithmetic: - arith := expr.Value.(ArithmeticExpression) + case Expression_Binary: + arith := expr.Value.(BinaryExpression) log.Printf("%+#v", arith) @@ -228,6 +228,8 @@ func compileExpressionWAT(expr Expression, block Block) (string, error) { op = getPrimitiveWATType(exprType) + ".div_" + suffix + "\n" case Arithmetic_Mod: op = getPrimitiveWATType(exprType) + ".rem_" + suffix + "\n" + default: + panic("operation not implemented") } return watLeft + castLeft + watRight + castRight + op + getTypeCast(exprType), nil @@ -319,7 +321,7 @@ func compileStatementWAT(stmt Statement, block Block) (string, error) { return wat + "local.set $" + strconv.Itoa(block.Locals[dlv.Variable].Index) + "\n", nil } - return "", nil + panic("stmt not implemented") } func compileBlockWAT(block Block) (string, error) { diff --git a/example/add.lang b/example/add.lang index a30748c..a951ef5 100644 --- a/example/add.lang +++ b/example/add.lang @@ -2,7 +2,11 @@ u64 add(u8 a, u8 b) { return add(a - 1u8, a); } -u64 addわ(u64 a, u64 b) { +u64 add2(u64 a, u64 b) { + if(a == b) { + return 0; + } + return a + b; } diff --git a/lexer.go b/lexer.go index 3675cec..4594e64 100644 --- a/lexer.go +++ b/lexer.go @@ -1,6 +1,7 @@ package main import ( + "log" "slices" "strconv" "strings" @@ -21,7 +22,7 @@ const ( type Keyword uint32 -var Keywords []string = []string{"import", "void", "return", "true", "false"} +var Keywords []string = []string{"import", "void", "return", "true", "false", "if", "else"} const ( Keyword_Import Keyword = iota @@ -29,6 +30,8 @@ const ( Keyword_Return Keyword_True KeyWord_False + Keyword_If + Keyword_Else ) type Separator uint32 @@ -46,7 +49,7 @@ const ( type Operator uint32 -var Operators []rune = []rune{'=', '>', '<', '!', '+', '-', '*', '/', '%'} +var Operators []string = []string{"=", ">", "<", "!", "+", "-", "*", "/", "%", "==", ">=", "<=", "!=", "+=", "-=", "*=", "/=", "%="} const ( Operator_Equals Operator = iota @@ -58,6 +61,15 @@ const ( Operator_Multiply Operator_Divide Operator_Modulo + Operator_EqualsEquals + Operator_GreaterEquals + Operator_LessEquals + Operator_NotEquals + Operator_PlusEquals + Operator_MinusEquals + Operator_MultiplyEquals + Operator_DivideEquals + Operator_ModuloEquals ) type LiteralType uint32 @@ -98,6 +110,30 @@ func (l *Lexer) peekRune() *rune { return &l.Runes[0] } +func (l *Lexer) tryOperator() Operator { + var foundOp Operator = InvalidValue + var foundOpLen int = 0 + + str := string(l.Runes) + for i, operator := range Operators { + operatorLen := len([]rune(operator)) + if operatorLen <= foundOpLen { + continue + } + + if strings.HasPrefix(str, operator) { + foundOp = Operator(i) + foundOpLen = len([]rune(operator)) + } + } + + for i := 0; i < foundOpLen; i++ { + l.nextRune() + } + + return foundOp +} + func (l *Lexer) nextRune() *rune { if len(l.Runes) == 0 { return nil @@ -145,12 +181,12 @@ func (l *Lexer) stringLiteral() (string, error) { } // TODO: maybe this method should directly return LexToken -func (l *Lexer) nextToken() (string, error) { +func (l *Lexer) nextToken() (*LexToken, error) { // Skip whitespace for { r := l.peekRune() if r == nil { - return "", nil + return nil, nil } if !slices.Contains(Whitespace, *r) { @@ -164,55 +200,40 @@ func (l *Lexer) nextToken() (string, error) { r := l.peekRune() if r == nil { - return "", nil + return nil, nil } if *r == '"' { literal, err := l.stringLiteral() if err != nil { - return "", err + return nil, err } - return "\"" + literal + "\"", nil + return &LexToken{Type: Type_Literal, Position: l.LastTokenPosition, Value: Literal{Type: Literal_String, Primitive: InvalidValue, Value: literal}}, nil + } + + op := l.tryOperator() + if op != InvalidValue { + return &LexToken{Type: Type_Operator, Position: l.LastTokenPosition, Value: op}, nil } token := "" for { r := l.peekRune() - if r == nil || slices.Contains(Whitespace, *r) || slices.Contains(Separators, *r) || slices.Contains(Operators, *r) { + if r == nil || slices.Contains(Whitespace, *r) || slices.Contains(Separators, *r) { break } token += string(*l.nextRune()) } - if len(token) == 0 && len(l.Runes) != 0 { - return string(*l.nextRune()), nil - } + if len(token) == 0 { + if len(l.Runes) == 0 { + println("E3") + return nil, nil + } - return token, nil -} - -func parseNumber(raw string, numberType PrimitiveType) (any, error) { - // TODO: return compiler errors - if isSignedInt(numberType) { - return strconv.ParseInt(raw, 10, getBits(numberType)) - } - - if isUnsignedInt(numberType) { - return strconv.ParseUint(raw, 10, getBits(numberType)) - } - - if isFloatingPoint(numberType) { - return strconv.ParseFloat(raw, getBits(numberType)) - } - - panic("Unhandled type (" + strconv.FormatUint(uint64(numberType), 10) + ") in parseNumber()") -} - -func (l *Lexer) parseToken(token string) (*LexToken, error) { - if strings.HasPrefix(token, "\"") { - return &LexToken{Type: Type_Literal, Position: l.LastTokenPosition, Value: Literal{Type: Literal_String, Primitive: InvalidValue, Value: token[1 : len(token)-1]}}, nil + token = string(*l.nextRune()) } runes := []rune(token) @@ -254,10 +275,6 @@ func (l *Lexer) parseToken(token string) (*LexToken, error) { if idx := slices.Index(Separators, runes[0]); idx != -1 { return &LexToken{Type: Type_Separator, Position: l.LastTokenPosition, Value: Separator(idx)}, nil } - - if idx := slices.Index(Operators, runes[0]); idx != -1 { - return &LexToken{Type: Type_Operator, Position: l.LastTokenPosition, Value: Operator(idx)}, nil - } } if idx := slices.Index(Keywords, token); idx != -1 { @@ -267,6 +284,23 @@ func (l *Lexer) parseToken(token string) (*LexToken, error) { return &LexToken{Type: Type_Identifier, Position: l.LastTokenPosition, Value: token}, nil } +func parseNumber(raw string, numberType PrimitiveType) (any, error) { + // TODO: return compiler errors + if isSignedInt(numberType) { + return strconv.ParseInt(raw, 10, getBits(numberType)) + } + + if isUnsignedInt(numberType) { + return strconv.ParseUint(raw, 10, getBits(numberType)) + } + + if isFloatingPoint(numberType) { + return strconv.ParseFloat(raw, getBits(numberType)) + } + + panic("Unhandled type (" + strconv.FormatUint(uint64(numberType), 10) + ") in parseNumber()") +} + func lexer(program string) ([]LexToken, error) { var tokens []LexToken @@ -274,20 +308,16 @@ func lexer(program string) ([]LexToken, error) { for { token, err := lexer.nextToken() + log.Printf("%+#v %+#v", token, err) if err != nil { return nil, err } - if len(token) == 0 { + if token == nil { break } - lexToken, err := lexer.parseToken(token) - if err != nil { - return nil, err - } - - tokens = append(tokens, *lexToken) + tokens = append(tokens, *token) } return tokens, nil diff --git a/parser.go b/parser.go index 8c1db03..e644059 100644 --- a/parser.go +++ b/parser.go @@ -39,6 +39,7 @@ const ( Statement_Block Statement_Return Statement_DeclareLocalVariable + Statement_If ) type Statement struct { @@ -65,13 +66,19 @@ type DeclareLocalVariableStatement struct { Initializer *Expression } +type IfStatement struct { + Condition Expression + ConditionalBlock Block + ElseBlock *Block +} + type ExpressionType uint32 const ( Expression_Assignment ExpressionType = iota Expression_Literal Expression_VariableReference - Expression_Arithmetic + Expression_Binary Expression_Tuple Expression_FunctionCall Expression_Negate @@ -105,9 +112,16 @@ const ( Arithmetic_Mul Arithmetic_Div Arithmetic_Mod + Arithmetic_Greater + Arithmetic_Less + Arithmetic_GreaterEquals + Arithmetic_LessEquals + Arithmetic_LogicalNot + Arithmetic_NotEquals + Arithmetic_Equals ) -type ArithmeticExpression struct { +type BinaryExpression struct { Operation ArithmeticOperation Left Expression Right Expression @@ -502,23 +516,7 @@ func (p *Parser) tryMultiplicativeExpression() (*Expression, error) { return nil, p.error("expected expression") } - var operation ArithmeticOperation - switch *op { - case Operator_Multiply: - operation = Arithmetic_Mul - case Operator_Divide: - operation = Arithmetic_Div - case Operator_Plus: - operation = Arithmetic_Add - case Operator_Minus: - operation = Arithmetic_Sub - case Operator_Modulo: - fallthrough - default: - operation = Arithmetic_Mod - } - - left = &Expression{Type: Expression_Arithmetic, Value: ArithmeticExpression{Operation: operation, Left: *left, Right: *right}, Position: left.Position} + left = &Expression{Type: Expression_Binary, Value: BinaryExpression{Operation: getArithmeticOperation(*op), Left: *left, Right: *right}, Position: left.Position} } } @@ -551,23 +549,86 @@ func (p *Parser) tryAdditiveExpression() (*Expression, error) { return nil, p.error("expected expression") } - var operation ArithmeticOperation - if *op == Operator_Plus { - operation = Arithmetic_Add - } else { - operation = Arithmetic_Sub - } - - left = &Expression{Type: Expression_Arithmetic, Value: ArithmeticExpression{Operation: operation, Left: *left, Right: *right}, Position: left.Position} + left = &Expression{Type: Expression_Binary, Value: BinaryExpression{Operation: getArithmeticOperation(*op), Left: *left, Right: *right}, Position: left.Position} } } -func (p *Parser) tryArithmeticExpression() (*Expression, error) { - return p.tryAdditiveExpression() +func (p *Parser) tryRelationalExpression() (*Expression, error) { + left, err := p.tryAdditiveExpression() + if err != nil { + return nil, err + } + + if left == nil { + return nil, nil + } + + for { + op, err := p.tryOperator(Operator_Less, Operator_Greater, Operator_LessEquals, Operator_GreaterEquals) + if err != nil { + return nil, err + } + + if op == nil { + return left, nil + } + + right, err := p.tryAdditiveExpression() + if err != nil { + return nil, err + } + + if right == nil { + return nil, p.error("expected expression") + } + + left = &Expression{Type: Expression_Binary, Value: BinaryExpression{Operation: getArithmeticOperation(*op), Left: *left, Right: *right}, Position: left.Position} + } +} + +func (p *Parser) tryEqualityExpression() (*Expression, error) { + left, err := p.tryRelationalExpression() + if err != nil { + return nil, err + } + + if left == nil { + return nil, nil + } + + for { + op, err := p.tryOperator(Operator_EqualsEquals, Operator_NotEquals) + if err != nil { + return nil, err + } + + if op == nil { + return left, nil + } + + right, err := p.tryRelationalExpression() + if err != nil { + return nil, err + } + + if right == nil { + return nil, p.error("expected expression") + } + + left = &Expression{Type: Expression_Binary, Value: BinaryExpression{Operation: getArithmeticOperation(*op), Left: *left, Right: *right}, Position: left.Position} + } +} + +func (p *Parser) tryBinaryExpression() (*Expression, error) { + return p.tryEqualityExpression() } func (p *Parser) tryExpression() (*Expression, error) { - return p.tryArithmeticExpression() + return p.tryBinaryExpression() +} + +func (p *Parser) expectExpression() (*Expression, error) { + return p.expect(p.tryExpression, "expected expression") } func (p *Parser) tryTupleExpression() (*Expression, error) { @@ -698,6 +759,38 @@ func (p *Parser) expectStatement() (*Statement, error) { return &Statement{Type: Statement_Return, Value: ReturnStatement{Value: expr}, Position: token.Position}, nil } + if token.Type == Type_Keyword && token.Value.(Keyword) == Keyword_If { + p.nextToken() + + _, err := p.expectSeparator(Separator_OpenParen) + if err != nil { + return nil, err + } + + cond, err := p.expectExpression() + if err != nil { + return nil, err + } + + _, err = p.expectSeparator(Separator_CloseParen) + if err != nil { + return nil, err + } + + conditionalBlock, err := p.expectBlock() + if err != nil { + return nil, err + } + + tok := p.peekToken() + if tok == nil || tok.Type != Type_Keyword || tok.Value.(Keyword) != Keyword_Else { + return &Statement{Type: Statement_If, Value: IfStatement{Condition: *cond, ConditionalBlock: *conditionalBlock, ElseBlock: nil}, Position: token.Position}, nil + } + + p.nextToken() + // TODO: else block + } + if token.Type == Type_Separator && token.Value.(Separator) == Separator_OpenCurly { block, err := p.expectBlock() if err != nil { diff --git a/types.go b/types.go index 2d485b2..063f1fd 100644 --- a/types.go +++ b/types.go @@ -101,3 +101,43 @@ func getPrimitiveTypeByName(name string) (PrimitiveType, error) { return PrimitiveType(idx), nil } + +func isAssigmentOperator(operator Operator) bool { + switch operator { + case Operator_Equals, Operator_PlusEquals, Operator_MinusEquals, Operator_MultiplyEquals, Operator_DivideEquals, Operator_ModuloEquals: + return true + default: + return false + } +} + +func getArithmeticOperation(operator Operator) ArithmeticOperation { + switch operator { + case Operator_Greater: + return Arithmetic_Greater + case Operator_Less: + return Arithmetic_Less + case Operator_Not: + return Arithmetic_LogicalNot + case Operator_Plus, Operator_PlusEquals: + return Arithmetic_Add + case Operator_Minus, Operator_MinusEquals: + return Arithmetic_Sub + case Operator_Multiply, Operator_MultiplyEquals: + return Arithmetic_Mul + case Operator_Divide, Operator_DivideEquals: + return Arithmetic_Div + case Operator_Modulo, Operator_ModuloEquals: + return Arithmetic_Mod + case Operator_EqualsEquals: + return Arithmetic_Equals + case Operator_GreaterEquals: + return Arithmetic_GreaterEquals + case Operator_LessEquals: + return Arithmetic_LessEquals + case Operator_NotEquals: + return Arithmetic_NotEquals + default: + return InvalidValue + } +} diff --git a/validator.go b/validator.go index 12350de..582ae5b 100644 --- a/validator.go +++ b/validator.go @@ -117,8 +117,8 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B } expr.ValueType = &local.Type - case Expression_Arithmetic: - arithmethic := expr.Value.(ArithmeticExpression) + case Expression_Binary: + arithmethic := expr.Value.(BinaryExpression) errors = append(errors, v.validateExpression(&arithmethic.Left, block)...) errors = append(errors, v.validateExpression(&arithmethic.Right, block)...)