Compile if-else statements

This commit is contained in:
MrLetsplay 2024-03-24 20:53:34 +01:00
parent fa63fee64d
commit 35ec9e97ca
Signed by: mr
SSH Key Fingerprint: SHA256:92jBH80vpXyaZHjaIl47pjRq+Yt7XGTArqQg1V7hSqg
4 changed files with 63 additions and 33 deletions

View File

@ -150,7 +150,7 @@ func castPrimitiveWAT(from PrimitiveType, to PrimitiveType) (string, error) {
return "i32.wrap_i64\n" + getTypeCast(to), nil return "i32.wrap_i64\n" + getTypeCast(to), nil
} }
func compileExpressionWAT(expr Expression, block Block) (string, error) { func compileExpressionWAT(expr Expression, block *Block) (string, error) {
var err error var err error
switch expr.Type { switch expr.Type {
@ -291,7 +291,7 @@ func compileExpressionWAT(expr Expression, block Block) (string, error) {
panic("expr not implemented") panic("expr not implemented")
} }
func compileStatementWAT(stmt Statement, block Block) (string, error) { func compileStatementWAT(stmt Statement, block *Block) (string, error) {
switch stmt.Type { switch stmt.Type {
case Statement_Expression: case Statement_Expression:
expr := stmt.Value.(ExpressionStatement) expr := stmt.Value.(ExpressionStatement)
@ -368,7 +368,7 @@ func compileStatementWAT(stmt Statement, block Block) (string, error) {
if ifS.ElseBlock != nil { if ifS.ElseBlock != nil {
// condition is false // condition is false
elseWAT, err := compileBlockWAT(*ifS.ElseBlock) elseWAT, err := compileBlockWAT(ifS.ElseBlock)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -383,7 +383,7 @@ func compileStatementWAT(stmt Statement, block Block) (string, error) {
panic("stmt not implemented") panic("stmt not implemented")
} }
func compileBlockWAT(block Block) (string, error) { func compileBlockWAT(block *Block) (string, error) {
blockWAT := "" blockWAT := ""
for _, stmt := range block.Statements { for _, stmt := range block.Statements {

View File

@ -1,7 +1,6 @@
package main package main
import ( import (
"log"
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
@ -308,7 +307,6 @@ func lexer(program string) ([]LexToken, error) {
for { for {
token, err := lexer.nextToken() token, err := lexer.nextToken()
log.Printf("%+#v %+#v", token, err)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -53,7 +53,7 @@ type ExpressionStatement struct {
} }
type BlockStatement struct { type BlockStatement struct {
Block Block Block *Block
} }
type ReturnStatement struct { type ReturnStatement struct {
@ -68,7 +68,7 @@ type DeclareLocalVariableStatement struct {
type IfStatement struct { type IfStatement struct {
Condition Expression Condition Expression
ConditionalBlock Block ConditionalBlock *Block
ElseBlock *Block ElseBlock *Block
} }
@ -163,7 +163,7 @@ type ParsedFunction struct {
Name string Name string
Parameters []ParsedParameter Parameters []ParsedParameter
ReturnType *Type ReturnType *Type
Body Block Body *Block
Locals []Local // All of the locals of the function, ordered by their index Locals []Local // All of the locals of the function, ordered by their index
} }
@ -645,7 +645,7 @@ func (p *Parser) tryDeclareLocalVariableStatement() (*Statement, error) {
return &Statement{Type: Statement_DeclareLocalVariable, Value: DeclareLocalVariableStatement{Variable: variableName, VariableType: *variableType, Initializer: initializer}, Position: variableType.Position}, nil return &Statement{Type: Statement_DeclareLocalVariable, Value: DeclareLocalVariableStatement{Variable: variableName, VariableType: *variableType, Initializer: initializer}, Position: variableType.Position}, nil
} }
func (p *Parser) expectStatement() (*Statement, error) { func (p *Parser) expectStatement(block *Block) (*Statement, error) {
token := p.peekToken() token := p.peekToken()
if token == nil { if token == nil {
return nil, p.error("expected statement") return nil, p.error("expected statement")
@ -695,27 +695,45 @@ func (p *Parser) expectStatement() (*Statement, error) {
return nil, err return nil, err
} }
conditionalBlock, err := p.expectBlock() conditionalBlock, err := p.expectBlock(block)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tok := p.peekToken() tok := p.peekToken()
if tok == nil || tok.Type != Type_Keyword || tok.Value.(Keyword) != Keyword_Else { if tok == nil || tok.Type != Type_Keyword || tok.Value.(Keyword) != Keyword_Else {
return &Statement{Type: Statement_If, Value: IfStatement{Condition: *cond, ConditionalBlock: *conditionalBlock, ElseBlock: nil}, Position: token.Position}, nil return &Statement{Type: Statement_If, Value: IfStatement{Condition: *cond, ConditionalBlock: conditionalBlock, ElseBlock: nil}, Position: token.Position}, nil
} }
p.nextToken() p.nextToken()
// TODO: else block
var elseBlock *Block
tok = p.peekToken()
if tok.Type == Type_Keyword && tok.Value.(Keyword) == Keyword_If {
stmt, err := p.expectStatement(block)
if err != nil {
return nil, err
}
elseBlock = &Block{Parent: block, Statements: []Statement{*stmt}}
} else {
elseBlock, err = p.expectBlock(block)
if err != nil {
return nil, err
}
}
return &Statement{Type: Statement_If, Value: IfStatement{Condition: *cond, ConditionalBlock: conditionalBlock, ElseBlock: elseBlock}, Position: token.Position}, nil
} }
if token.Type == Type_Separator && token.Value.(Separator) == Separator_OpenCurly { if token.Type == Type_Separator && token.Value.(Separator) == Separator_OpenCurly {
block, err := p.expectBlock() block, err := p.expectBlock(block)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &Statement{Type: Statement_Block, Value: BlockStatement{Block: *block}, Position: token.Position}, nil return &Statement{Type: Statement_Block, Value: BlockStatement{Block: block}, Position: token.Position}, nil
} }
stmt, err := p.tryDeclareLocalVariableStatement() stmt, err := p.tryDeclareLocalVariableStatement()
@ -744,13 +762,14 @@ func (p *Parser) expectStatement() (*Statement, error) {
return nil, p.error("expected statement") return nil, p.error("expected statement")
} }
func (p *Parser) expectBlock() (*Block, error) { func (p *Parser) expectBlock(parent *Block) (*Block, error) {
_, err := p.expectSeparator(Separator_OpenCurly) _, err := p.expectSeparator(Separator_OpenCurly)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var statements []Statement block := &Block{Parent: parent}
for { for {
token := p.peekToken() token := p.peekToken()
if token == nil { if token == nil {
@ -762,15 +781,15 @@ func (p *Parser) expectBlock() (*Block, error) {
break break
} }
stmt, err := p.expectStatement() stmt, err := p.expectStatement(block)
if err != nil { if err != nil {
return nil, err return nil, err
} }
statements = append(statements, *stmt) block.Statements = append(block.Statements, *stmt)
} }
return &Block{Statements: statements}, nil return block, nil
} }
func (p *Parser) expectFunction() (*ParsedFunction, error) { func (p *Parser) expectFunction() (*ParsedFunction, error) {
@ -835,12 +854,12 @@ func (p *Parser) expectFunction() (*ParsedFunction, error) {
parameters = append(parameters, ParsedParameter{Name: paramName, Type: *paramType}) parameters = append(parameters, ParsedParameter{Name: paramName, Type: *paramType})
} }
body, err = p.expectBlock() body, err = p.expectBlock(nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &ParsedFunction{Name: name, Parameters: parameters, ReturnType: returnType, Body: *body}, nil return &ParsedFunction{Name: name, Parameters: parameters, ReturnType: returnType, Body: body}, nil
} }
func (p *Parser) parseFile() (*ParsedFile, error) { func (p *Parser) parseFile() (*ParsedFile, error) {

View File

@ -1,7 +1,6 @@
package main package main
import ( import (
"log"
"strconv" "strconv"
) )
@ -39,7 +38,6 @@ func isTypeExpandableTo(from Type, to Type) bool {
return true return true
} }
log.Printf("%+#v %+#v", from, to)
panic("not implemented") panic("not implemented")
} }
@ -103,15 +101,26 @@ func (v *Validator) getArithmeticResultType(expr *Expression, left PrimitiveType
return InvalidValue, v.createError("cannot use these types in an arithmetic expression without an explicit cast", expr.Position) // TODO: include type names in error return InvalidValue, v.createError("cannot use these types in an arithmetic expression without an explicit cast", expr.Position) // TODO: include type names in error
} }
func getLocal(block *Block, variable string) *Local {
if local, ok := block.Locals[variable]; ok {
return &local
}
if block.Parent == nil {
return nil
}
return getLocal(block.Parent, variable)
}
func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error { func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error {
var errors []error var errors []error
switch expr.Type { switch expr.Type {
case Expression_Assignment: case Expression_Assignment:
assignment := expr.Value.(AssignmentExpression) assignment := expr.Value.(AssignmentExpression)
var local Local local := getLocal(v.currentBlock, assignment.Variable)
var ok bool if local == nil {
if local, ok = v.currentBlock.Locals[assignment.Variable]; !ok {
errors = append(errors, v.createError("Assignment to undeclared variable "+assignment.Variable, expr.Position)) errors = append(errors, v.createError("Assignment to undeclared variable "+assignment.Variable, expr.Position))
return errors return errors
} }
@ -136,10 +145,11 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error
} }
case Expression_VariableReference: case Expression_VariableReference:
reference := expr.Value.(VariableReferenceExpression) reference := expr.Value.(VariableReferenceExpression)
var local Local local := getLocal(v.currentBlock, reference.Variable)
var ok bool if local == nil {
if local, ok = v.currentBlock.Locals[reference.Variable]; !ok {
errors = append(errors, v.createError("Reference to undeclared variable "+reference.Variable, expr.Position)) errors = append(errors, v.createError("Reference to undeclared variable "+reference.Variable, expr.Position))
//panic("E")
println("ERROR")
return errors return errors
} }
@ -313,7 +323,7 @@ func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local)
stmt.Value = expression stmt.Value = expression
case Statement_Block: case Statement_Block:
block := stmt.Value.(BlockStatement) block := stmt.Value.(BlockStatement)
errors = append(errors, v.validateBlock(&block.Block, functionLocals)...) errors = append(errors, v.validateBlock(block.Block, functionLocals)...)
stmt.Value = block stmt.Value = block
case Statement_Return: case Statement_Return:
ret := stmt.Value.(ReturnStatement) ret := stmt.Value.(ReturnStatement)
@ -335,6 +345,8 @@ func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local)
} else if v.currentFunction.ReturnType != nil { } else if v.currentFunction.ReturnType != nil {
errors = append(errors, v.createError("missing return value", stmt.Position)) errors = append(errors, v.createError("missing return value", stmt.Position))
} }
stmt.Value = ret
case Statement_DeclareLocalVariable: case Statement_DeclareLocalVariable:
dlv := stmt.Value.(DeclareLocalVariableStatement) dlv := stmt.Value.(DeclareLocalVariableStatement)
if dlv.Initializer != nil { if dlv.Initializer != nil {
@ -355,7 +367,7 @@ func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local)
ifS := stmt.Value.(IfStatement) ifS := stmt.Value.(IfStatement)
errors = append(errors, v.validateExpression(&ifS.Condition)...) errors = append(errors, v.validateExpression(&ifS.Condition)...)
errors = append(errors, v.validateBlock(&ifS.ConditionalBlock, functionLocals)...) errors = append(errors, v.validateBlock(ifS.ConditionalBlock, functionLocals)...)
if ifS.ElseBlock != nil { if ifS.ElseBlock != nil {
errors = append(errors, v.validateBlock(ifS.ElseBlock, functionLocals)...) errors = append(errors, v.validateBlock(ifS.ElseBlock, functionLocals)...)
@ -386,6 +398,7 @@ func (v *Validator) validateBlock(block *Block, functionLocals *[]Local) []error
for i := range block.Statements { for i := range block.Statements {
v.currentBlock = block v.currentBlock = block
println("CURRBLOCK")
stmt := &block.Statements[i] stmt := &block.Statements[i]
errors = append(errors, v.validateStatement(stmt, functionLocals)...) errors = append(errors, v.validateStatement(stmt, functionLocals)...)
} }
@ -400,7 +413,7 @@ func (v *Validator) validateFunction(function *ParsedFunction) []error {
v.currentFunction = function v.currentFunction = function
body := &function.Body body := function.Body
body.Locals = make(map[string]Local) body.Locals = make(map[string]Local)
for _, param := range function.Parameters { for _, param := range function.Parameters {
local := Local{Name: param.Name, Type: param.Type, IsParameter: true, Index: len(locals)} local := Local{Name: param.Name, Type: param.Type, IsParameter: true, Index: len(locals)}