From f5168a73bff7d47a76b93192cbb10191a2737e9f Mon Sep 17 00:00:00 2001 From: MrLetsplay Date: Thu, 4 Apr 2024 20:05:57 +0200 Subject: [PATCH] Implement while loops, Fix variable references --- backend_wat.go | 40 ++++++++++++++++++++++++++++++++++------ example/test.ely | 13 +++++++------ lexer.go | 3 ++- parser.go | 39 ++++++++++++++++++++++++++++++++++++++- validator.go | 15 +++++++++++++++ 5 files changed, 96 insertions(+), 14 deletions(-) diff --git a/backend_wat.go b/backend_wat.go index 8860507..8aab8f8 100644 --- a/backend_wat.go +++ b/backend_wat.go @@ -2,6 +2,7 @@ package main import ( "errors" + "fmt" "strconv" "strings" "unicode" @@ -29,7 +30,7 @@ func getPrimitiveWATType(primitive PrimitiveType) string { return "i32" } - panic("unhandled type") + panic(fmt.Sprintf("unhandled type in getPrimitiveWATType(): %s", primitive)) } func safeASCIIIdentifier(identifier string) string { @@ -79,7 +80,7 @@ func pushConstantNumberWAT(primitive PrimitiveType, value any) string { return "f64.const " + strconv.FormatFloat(value.(float64), 'f', -1, 64) + "\n" } - panic("invalid type") + panic(fmt.Sprintf("invalid type passed to pushConstantNumberWAT(): %s", primitive)) } func (c *Compiler) getAddressWATType() string { @@ -97,10 +98,10 @@ func (c *Compiler) getWATType(t Type) string { case Type_Named, Type_Array: return c.getAddressWATType() case Type_Tuple: - panic("tuple type passed to getWATType()") + panic(fmt.Sprintf("tuple type passed to getWATType(): %s", t)) } - panic("type not implemented in getWATType()") + panic(fmt.Sprintf("type not implemented in getWATType(): %s", t)) } func castPrimitiveWAT(from PrimitiveType, to PrimitiveType) (string, error) { @@ -182,7 +183,7 @@ func (c *Compiler) compileAssignmentExpressionWAT(assignment AssignmentExpressio switch lhs.Type { case Expression_VariableReference: ref := lhs.Value.(VariableReferenceExpression) - local := strconv.Itoa(c.CurrentBlock.Locals[ref.Variable].Index) + local := strconv.Itoa(getLocal(c.CurrentBlock, ref.Variable).Index) return exprWAT + "local.tee $" + local + "\n", nil case Expression_ArrayAccess: panic("TODO") // TODO @@ -241,7 +242,7 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) { cast = getTypeCast(expr.ValueType.Value.(PrimitiveType)) } - return "local.get $" + strconv.Itoa(c.CurrentBlock.Locals[ref.Variable].Index) + "\n" + cast, nil + return "local.get $" + strconv.Itoa(getLocal(c.CurrentBlock, ref.Variable).Index) + "\n" + cast, nil case Expression_Binary: binary := expr.Value.(BinaryExpression) @@ -467,6 +468,33 @@ func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, er wat += "end\n" } + return wat, nil + case Statement_WhileLoop: + while := stmt.Value.(WhileLoopStatement) + + conditionWAT, err := c.compileExpressionWAT(while.Condition) + if err != nil { + return "", err + } + + bodyWAT, err := c.compileBlockWAT(while.Body) + if err != nil { + return "", err + } + + wat := "block\n" + wat += "loop\n" + + wat += conditionWAT + wat += "i32.eqz\n" + wat += "br_if 1\n" + + wat += bodyWAT + wat += "br 0\n" + + wat += "end\n" + wat += "end\n" + return wat, nil } diff --git a/example/test.ely b/example/test.ely index 687b063..f748a07 100644 --- a/example/test.ely +++ b/example/test.ely @@ -1,7 +1,8 @@ -void b(u64 x) { - raw(u64, 0x69u64) = x; -} - -u64 a() { - return raw(u64, 0x69u64); +u32 c(u32 n) { + u32 sum = 0u32; + while(n > 0u32) { + sum += n; + n -= 1u32; + } + return sum; } diff --git a/lexer.go b/lexer.go index 5b43fc4..5a5e8ca 100644 --- a/lexer.go +++ b/lexer.go @@ -22,7 +22,7 @@ const ( type Keyword uint32 -var Keywords []string = []string{"import", "module", "void", "return", "true", "false", "if", "else", "raw"} +var Keywords []string = []string{"import", "module", "void", "return", "true", "false", "if", "else", "raw", "while"} const ( Keyword_Import Keyword = iota @@ -34,6 +34,7 @@ const ( Keyword_If Keyword_Else Keyword_Raw + Keyword_While ) type Separator uint32 diff --git a/parser.go b/parser.go index 119a13d..d0478a2 100644 --- a/parser.go +++ b/parser.go @@ -40,6 +40,7 @@ const ( Statement_Return Statement_DeclareLocalVariable Statement_If + Statement_WhileLoop ) type Statement struct { @@ -72,6 +73,11 @@ type IfStatement struct { ElseBlock *Block } +type WhileLoopStatement struct { + Condition Expression + Body *Block +} + type ExpressionType uint32 const ( @@ -681,7 +687,7 @@ func (p *Parser) tryAssignmentExpression() (*Expression, error) { return p.tryBinaryExpression() } - op, err := pCopy.tryOperator(Operator_Equals) + op, err := pCopy.tryOperator(Operator_Equals, Operator_PlusEquals, Operator_MinusEquals, Operator_MultiplyEquals, Operator_DivideEquals, Operator_ModuloEquals) if err != nil { return nil, err } @@ -695,6 +701,11 @@ func (p *Parser) tryAssignmentExpression() (*Expression, error) { return nil, err } + if *op != Operator_Equals { + operation := getOperation(*op) + expr = &Expression{Type: Expression_Binary, Value: BinaryExpression{Left: *lhs, Right: *expr, Operation: operation}, Position: lhs.Position} + } + *p = pCopy return &Expression{Type: Expression_Assignment, Value: AssignmentExpression{Lhs: *lhs, Value: *expr}, Position: lhs.Position}, nil } @@ -885,6 +896,32 @@ func (p *Parser) expectStatement(block *Block) (*Statement, error) { return &Statement{Type: Statement_If, Value: IfStatement{Condition: *cond, ConditionalBlock: conditionalBlock, ElseBlock: elseBlock}, Position: token.Position}, nil } + if token.Type == Type_Keyword && token.Value.(Keyword) == Keyword_While { + p.nextToken() + + _, err := p.expectSeparator(Separator_OpenParen) + if err != nil { + return nil, err + } + + cond, err := p.expectExpression() + if err != nil { + return nil, err + } + + _, err = p.expectSeparator(Separator_CloseParen) + if err != nil { + return nil, err + } + + body, err := p.expectBlock(block) + if err != nil { + return nil, err + } + + return &Statement{Type: Statement_WhileLoop, Value: WhileLoopStatement{Condition: *cond, Body: body}, Position: token.Position}, nil + } + if token.Type == Type_Separator && token.Value.(Separator) == Separator_OpenCurly { block, err := p.expectBlock(block) if err != nil { diff --git a/validator.go b/validator.go index 4b560de..0ec95f3 100644 --- a/validator.go +++ b/validator.go @@ -462,6 +462,21 @@ func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local) } stmt.Value = ifS + case Statement_WhileLoop: + while := stmt.Value.(WhileLoopStatement) + + errors = append(errors, v.validateExpression(&while.Condition)...) + errors = append(errors, v.validateBlock(while.Body, functionLocals)...) + + if len(errors) != 0 { + return errors + } + + if while.Condition.ValueType.Type != Type_Primitive || while.Condition.ValueType.Value.(PrimitiveType) != Primitive_Bool { + errors = append(errors, v.createError("condition must evaluate to boolean", while.Condition.Position)) + } + + stmt.Value = while default: panic("stmt not implemented") }