From 35ec9e97ca9e096fa7170271dcd2bc74510fab88 Mon Sep 17 00:00:00 2001 From: MrLetsplay Date: Sun, 24 Mar 2024 20:53:34 +0100 Subject: [PATCH] Compile if-else statements --- backend_wat.go | 8 ++++---- lexer.go | 2 -- parser.go | 51 ++++++++++++++++++++++++++++++++++---------------- validator.go | 35 +++++++++++++++++++++++----------- 4 files changed, 63 insertions(+), 33 deletions(-) diff --git a/backend_wat.go b/backend_wat.go index d72395a..b490490 100644 --- a/backend_wat.go +++ b/backend_wat.go @@ -150,7 +150,7 @@ func castPrimitiveWAT(from PrimitiveType, to PrimitiveType) (string, error) { return "i32.wrap_i64\n" + getTypeCast(to), nil } -func compileExpressionWAT(expr Expression, block Block) (string, error) { +func compileExpressionWAT(expr Expression, block *Block) (string, error) { var err error switch expr.Type { @@ -291,7 +291,7 @@ func compileExpressionWAT(expr Expression, block Block) (string, error) { panic("expr not implemented") } -func compileStatementWAT(stmt Statement, block Block) (string, error) { +func compileStatementWAT(stmt Statement, block *Block) (string, error) { switch stmt.Type { case Statement_Expression: expr := stmt.Value.(ExpressionStatement) @@ -368,7 +368,7 @@ func compileStatementWAT(stmt Statement, block Block) (string, error) { if ifS.ElseBlock != nil { // condition is false - elseWAT, err := compileBlockWAT(*ifS.ElseBlock) + elseWAT, err := compileBlockWAT(ifS.ElseBlock) if err != nil { return "", err } @@ -383,7 +383,7 @@ func compileStatementWAT(stmt Statement, block Block) (string, error) { panic("stmt not implemented") } -func compileBlockWAT(block Block) (string, error) { +func compileBlockWAT(block *Block) (string, error) { blockWAT := "" for _, stmt := range block.Statements { diff --git a/lexer.go b/lexer.go index 4594e64..0a1e1aa 100644 --- a/lexer.go +++ b/lexer.go @@ -1,7 +1,6 @@ package main import ( - "log" "slices" "strconv" "strings" @@ -308,7 +307,6 @@ func lexer(program string) ([]LexToken, error) { for { token, err := lexer.nextToken() - log.Printf("%+#v %+#v", token, err) if err != nil { return nil, err } diff --git a/parser.go b/parser.go index 745453f..1ef002d 100644 --- a/parser.go +++ b/parser.go @@ -53,7 +53,7 @@ type ExpressionStatement struct { } type BlockStatement struct { - Block Block + Block *Block } type ReturnStatement struct { @@ -68,7 +68,7 @@ type DeclareLocalVariableStatement struct { type IfStatement struct { Condition Expression - ConditionalBlock Block + ConditionalBlock *Block ElseBlock *Block } @@ -163,7 +163,7 @@ type ParsedFunction struct { Name string Parameters []ParsedParameter ReturnType *Type - Body Block + Body *Block Locals []Local // All of the locals of the function, ordered by their index } @@ -645,7 +645,7 @@ func (p *Parser) tryDeclareLocalVariableStatement() (*Statement, error) { return &Statement{Type: Statement_DeclareLocalVariable, Value: DeclareLocalVariableStatement{Variable: variableName, VariableType: *variableType, Initializer: initializer}, Position: variableType.Position}, nil } -func (p *Parser) expectStatement() (*Statement, error) { +func (p *Parser) expectStatement(block *Block) (*Statement, error) { token := p.peekToken() if token == nil { return nil, p.error("expected statement") @@ -695,27 +695,45 @@ func (p *Parser) expectStatement() (*Statement, error) { return nil, err } - conditionalBlock, err := p.expectBlock() + conditionalBlock, err := p.expectBlock(block) if err != nil { return nil, err } tok := p.peekToken() if tok == nil || tok.Type != Type_Keyword || tok.Value.(Keyword) != Keyword_Else { - return &Statement{Type: Statement_If, Value: IfStatement{Condition: *cond, ConditionalBlock: *conditionalBlock, ElseBlock: nil}, Position: token.Position}, nil + return &Statement{Type: Statement_If, Value: IfStatement{Condition: *cond, ConditionalBlock: conditionalBlock, ElseBlock: nil}, Position: token.Position}, nil } p.nextToken() - // TODO: else block + + var elseBlock *Block + + tok = p.peekToken() + if tok.Type == Type_Keyword && tok.Value.(Keyword) == Keyword_If { + stmt, err := p.expectStatement(block) + if err != nil { + return nil, err + } + + elseBlock = &Block{Parent: block, Statements: []Statement{*stmt}} + } else { + elseBlock, err = p.expectBlock(block) + if err != nil { + return nil, err + } + } + + return &Statement{Type: Statement_If, Value: IfStatement{Condition: *cond, ConditionalBlock: conditionalBlock, ElseBlock: elseBlock}, Position: token.Position}, nil } if token.Type == Type_Separator && token.Value.(Separator) == Separator_OpenCurly { - block, err := p.expectBlock() + block, err := p.expectBlock(block) if err != nil { return nil, err } - return &Statement{Type: Statement_Block, Value: BlockStatement{Block: *block}, Position: token.Position}, nil + return &Statement{Type: Statement_Block, Value: BlockStatement{Block: block}, Position: token.Position}, nil } stmt, err := p.tryDeclareLocalVariableStatement() @@ -744,13 +762,14 @@ func (p *Parser) expectStatement() (*Statement, error) { return nil, p.error("expected statement") } -func (p *Parser) expectBlock() (*Block, error) { +func (p *Parser) expectBlock(parent *Block) (*Block, error) { _, err := p.expectSeparator(Separator_OpenCurly) if err != nil { return nil, err } - var statements []Statement + block := &Block{Parent: parent} + for { token := p.peekToken() if token == nil { @@ -762,15 +781,15 @@ func (p *Parser) expectBlock() (*Block, error) { break } - stmt, err := p.expectStatement() + stmt, err := p.expectStatement(block) if err != nil { return nil, err } - statements = append(statements, *stmt) + block.Statements = append(block.Statements, *stmt) } - return &Block{Statements: statements}, nil + return block, nil } func (p *Parser) expectFunction() (*ParsedFunction, error) { @@ -835,12 +854,12 @@ func (p *Parser) expectFunction() (*ParsedFunction, error) { parameters = append(parameters, ParsedParameter{Name: paramName, Type: *paramType}) } - body, err = p.expectBlock() + body, err = p.expectBlock(nil) if err != nil { return nil, err } - return &ParsedFunction{Name: name, Parameters: parameters, ReturnType: returnType, Body: *body}, nil + return &ParsedFunction{Name: name, Parameters: parameters, ReturnType: returnType, Body: body}, nil } func (p *Parser) parseFile() (*ParsedFile, error) { diff --git a/validator.go b/validator.go index d1fee57..fda3962 100644 --- a/validator.go +++ b/validator.go @@ -1,7 +1,6 @@ package main import ( - "log" "strconv" ) @@ -39,7 +38,6 @@ func isTypeExpandableTo(from Type, to Type) bool { return true } - log.Printf("%+#v %+#v", from, to) panic("not implemented") } @@ -103,15 +101,26 @@ func (v *Validator) getArithmeticResultType(expr *Expression, left PrimitiveType return InvalidValue, v.createError("cannot use these types in an arithmetic expression without an explicit cast", expr.Position) // TODO: include type names in error } +func getLocal(block *Block, variable string) *Local { + if local, ok := block.Locals[variable]; ok { + return &local + } + + if block.Parent == nil { + return nil + } + + return getLocal(block.Parent, variable) +} + func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error { var errors []error switch expr.Type { case Expression_Assignment: assignment := expr.Value.(AssignmentExpression) - var local Local - var ok bool - if local, ok = v.currentBlock.Locals[assignment.Variable]; !ok { + local := getLocal(v.currentBlock, assignment.Variable) + if local == nil { errors = append(errors, v.createError("Assignment to undeclared variable "+assignment.Variable, expr.Position)) return errors } @@ -136,10 +145,11 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error } case Expression_VariableReference: reference := expr.Value.(VariableReferenceExpression) - var local Local - var ok bool - if local, ok = v.currentBlock.Locals[reference.Variable]; !ok { + local := getLocal(v.currentBlock, reference.Variable) + if local == nil { errors = append(errors, v.createError("Reference to undeclared variable "+reference.Variable, expr.Position)) + //panic("E") + println("ERROR") return errors } @@ -313,7 +323,7 @@ func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local) stmt.Value = expression case Statement_Block: block := stmt.Value.(BlockStatement) - errors = append(errors, v.validateBlock(&block.Block, functionLocals)...) + errors = append(errors, v.validateBlock(block.Block, functionLocals)...) stmt.Value = block case Statement_Return: ret := stmt.Value.(ReturnStatement) @@ -335,6 +345,8 @@ func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local) } else if v.currentFunction.ReturnType != nil { errors = append(errors, v.createError("missing return value", stmt.Position)) } + + stmt.Value = ret case Statement_DeclareLocalVariable: dlv := stmt.Value.(DeclareLocalVariableStatement) if dlv.Initializer != nil { @@ -355,7 +367,7 @@ func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local) ifS := stmt.Value.(IfStatement) errors = append(errors, v.validateExpression(&ifS.Condition)...) - errors = append(errors, v.validateBlock(&ifS.ConditionalBlock, functionLocals)...) + errors = append(errors, v.validateBlock(ifS.ConditionalBlock, functionLocals)...) if ifS.ElseBlock != nil { errors = append(errors, v.validateBlock(ifS.ElseBlock, functionLocals)...) @@ -386,6 +398,7 @@ func (v *Validator) validateBlock(block *Block, functionLocals *[]Local) []error for i := range block.Statements { v.currentBlock = block + println("CURRBLOCK") stmt := &block.Statements[i] errors = append(errors, v.validateStatement(stmt, functionLocals)...) } @@ -400,7 +413,7 @@ func (v *Validator) validateFunction(function *ParsedFunction) []error { v.currentFunction = function - body := &function.Body + body := function.Body body.Locals = make(map[string]Local) for _, param := range function.Parameters { local := Local{Name: param.Name, Type: param.Type, IsParameter: true, Index: len(locals)}