validate function returns

This commit is contained in:
MrLetsplay 2024-04-20 14:50:20 +02:00
parent b9c1ad12c5
commit 8fb09ae880
Signed by: mr
SSH Key Fingerprint: SHA256:92jBH80vpXyaZHjaIl47pjRq+Yt7XGTArqQg1V7hSqg
4 changed files with 23 additions and 24 deletions

View File

@ -529,11 +529,17 @@ func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, er
return wat, nil return wat, nil
case Statement_Return: case Statement_Return:
ret := stmt.Value.(ReturnStatement) ret := stmt.Value.(ReturnStatement)
wat, err := c.compileExpressionWAT(*ret.Value)
wat := ""
if ret.Value != nil {
valueWAT, err := c.compileExpressionWAT(*ret.Value)
if err != nil { if err != nil {
return "", err return "", err
} }
wat += valueWAT
}
return wat + "return\n", nil return wat + "return\n", nil
case Statement_DeclareLocalVariable: case Statement_DeclareLocalVariable:
dlv := stmt.Value.(DeclareLocalVariableStatement) dlv := stmt.Value.(DeclareLocalVariableStatement)

View File

@ -47,6 +47,7 @@ type Statement struct {
Type StatementType Type StatementType
Value any Value any
Position TokenPosition Position TokenPosition
Returns bool
} }
type ExpressionStatement struct { type ExpressionStatement struct {
@ -190,6 +191,7 @@ type Block struct {
Parent *Block Parent *Block
Statements []Statement Statements []Statement
Locals map[string]Local Locals map[string]Local
Returns bool
} }
type ParsedParameter struct { type ParsedParameter struct {

View File

@ -1,8 +1,6 @@
package main package main
import ( import (
"errors"
"slices"
"strconv" "strconv"
) )
@ -93,15 +91,6 @@ func getBits(primitiveType PrimitiveType) int {
} }
} }
func getPrimitiveTypeByName(name string) (PrimitiveType, error) {
idx := slices.Index(PRIMITIVE_TYPE_NAMES, name)
if idx == -1 {
return InvalidValue, errors.New("not a primitive type name")
}
return PrimitiveType(idx), nil
}
func getPrimitiveTypeName(primitiveType PrimitiveType) string { func getPrimitiveTypeName(primitiveType PrimitiveType) string {
return PRIMITIVE_TYPE_NAMES[primitiveType] return PRIMITIVE_TYPE_NAMES[primitiveType]
} }
@ -139,15 +128,6 @@ func (t PrimitiveType) String() string {
return getPrimitiveTypeName(t) return getPrimitiveTypeName(t)
} }
func isAssigmentOperator(operator Operator) bool {
switch operator {
case Operator_Equals, Operator_PlusEquals, Operator_MinusEquals, Operator_MultiplyEquals, Operator_DivideEquals, Operator_ModuloEquals:
return true
default:
return false
}
}
func getOperation(operator Operator) Operation { func getOperation(operator Operator) Operation {
switch operator { switch operator {
case Operator_Greater: case Operator_Greater:

View File

@ -411,9 +411,12 @@ func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local)
case Statement_Block: case Statement_Block:
block := stmt.Value.(BlockStatement) block := stmt.Value.(BlockStatement)
errors = append(errors, v.validateBlock(block.Block, functionLocals)...) errors = append(errors, v.validateBlock(block.Block, functionLocals)...)
stmt.Returns = block.Block.Returns
stmt.Value = block stmt.Value = block
case Statement_Return: case Statement_Return:
ret := stmt.Value.(ReturnStatement) ret := stmt.Value.(ReturnStatement)
stmt.Returns = true
if ret.Value != nil { if ret.Value != nil {
if v.CurrentFunction.ReturnType == nil { if v.CurrentFunction.ReturnType == nil {
errors = append(errors, v.createError("cannot return value from void function", stmt.Position)) errors = append(errors, v.createError("cannot return value from void function", stmt.Position))
@ -470,6 +473,8 @@ func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local)
errors = append(errors, v.validateBlock(ifS.ElseBlock, functionLocals)...) errors = append(errors, v.validateBlock(ifS.ElseBlock, functionLocals)...)
} }
stmt.Returns = ifS.ConditionalBlock.Returns && ifS.ElseBlock != nil && ifS.ElseBlock.Returns
if len(errors) != 0 { if len(errors) != 0 {
return errors return errors
} }
@ -512,6 +517,10 @@ func (v *Validator) validateBlock(block *Block, functionLocals *[]Local) []error
v.CurrentBlock = block v.CurrentBlock = block
stmt := &block.Statements[i] stmt := &block.Statements[i]
errors = append(errors, v.validateStatement(stmt, functionLocals)...) errors = append(errors, v.validateStatement(stmt, functionLocals)...)
if stmt.Returns {
block.Returns = true
}
} }
return errors return errors
@ -534,7 +543,9 @@ func (v *Validator) validateFunction(function *ParsedFunction) []error {
errors = append(errors, v.validateBlock(body, &locals)...) errors = append(errors, v.validateBlock(body, &locals)...)
// TODO: validate that function returns return value if function.ReturnType != nil && !body.Returns {
errors = append(errors, v.createError("function must return a value", function.ReturnType.Position))
}
function.Locals = locals function.Locals = locals