validate function returns

This commit is contained in:
MrLetsplay 2024-04-20 14:50:20 +02:00
parent b9c1ad12c5
commit 8fb09ae880
Signed by: mr
SSH Key Fingerprint: SHA256:92jBH80vpXyaZHjaIl47pjRq+Yt7XGTArqQg1V7hSqg
4 changed files with 23 additions and 24 deletions

View File

@ -529,11 +529,17 @@ func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, er
return wat, nil
case Statement_Return:
ret := stmt.Value.(ReturnStatement)
wat, err := c.compileExpressionWAT(*ret.Value)
wat := ""
if ret.Value != nil {
valueWAT, err := c.compileExpressionWAT(*ret.Value)
if err != nil {
return "", err
}
wat += valueWAT
}
return wat + "return\n", nil
case Statement_DeclareLocalVariable:
dlv := stmt.Value.(DeclareLocalVariableStatement)

View File

@ -47,6 +47,7 @@ type Statement struct {
Type StatementType
Value any
Position TokenPosition
Returns bool
}
type ExpressionStatement struct {
@ -190,6 +191,7 @@ type Block struct {
Parent *Block
Statements []Statement
Locals map[string]Local
Returns bool
}
type ParsedParameter struct {

View File

@ -1,8 +1,6 @@
package main
import (
"errors"
"slices"
"strconv"
)
@ -93,15 +91,6 @@ func getBits(primitiveType PrimitiveType) int {
}
}
func getPrimitiveTypeByName(name string) (PrimitiveType, error) {
idx := slices.Index(PRIMITIVE_TYPE_NAMES, name)
if idx == -1 {
return InvalidValue, errors.New("not a primitive type name")
}
return PrimitiveType(idx), nil
}
func getPrimitiveTypeName(primitiveType PrimitiveType) string {
return PRIMITIVE_TYPE_NAMES[primitiveType]
}
@ -139,15 +128,6 @@ func (t PrimitiveType) String() string {
return getPrimitiveTypeName(t)
}
func isAssigmentOperator(operator Operator) bool {
switch operator {
case Operator_Equals, Operator_PlusEquals, Operator_MinusEquals, Operator_MultiplyEquals, Operator_DivideEquals, Operator_ModuloEquals:
return true
default:
return false
}
}
func getOperation(operator Operator) Operation {
switch operator {
case Operator_Greater:

View File

@ -411,9 +411,12 @@ func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local)
case Statement_Block:
block := stmt.Value.(BlockStatement)
errors = append(errors, v.validateBlock(block.Block, functionLocals)...)
stmt.Returns = block.Block.Returns
stmt.Value = block
case Statement_Return:
ret := stmt.Value.(ReturnStatement)
stmt.Returns = true
if ret.Value != nil {
if v.CurrentFunction.ReturnType == nil {
errors = append(errors, v.createError("cannot return value from void function", stmt.Position))
@ -470,6 +473,8 @@ func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local)
errors = append(errors, v.validateBlock(ifS.ElseBlock, functionLocals)...)
}
stmt.Returns = ifS.ConditionalBlock.Returns && ifS.ElseBlock != nil && ifS.ElseBlock.Returns
if len(errors) != 0 {
return errors
}
@ -512,6 +517,10 @@ func (v *Validator) validateBlock(block *Block, functionLocals *[]Local) []error
v.CurrentBlock = block
stmt := &block.Statements[i]
errors = append(errors, v.validateStatement(stmt, functionLocals)...)
if stmt.Returns {
block.Returns = true
}
}
return errors
@ -534,7 +543,9 @@ func (v *Validator) validateFunction(function *ParsedFunction) []error {
errors = append(errors, v.validateBlock(body, &locals)...)
// TODO: validate that function returns return value
if function.ReturnType != nil && !body.Returns {
errors = append(errors, v.createError("function must return a value", function.ReturnType.Position))
}
function.Locals = locals