validate function returns
This commit is contained in:
parent
b9c1ad12c5
commit
8fb09ae880
@ -529,11 +529,17 @@ func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, er
|
|||||||
return wat, nil
|
return wat, nil
|
||||||
case Statement_Return:
|
case Statement_Return:
|
||||||
ret := stmt.Value.(ReturnStatement)
|
ret := stmt.Value.(ReturnStatement)
|
||||||
wat, err := c.compileExpressionWAT(*ret.Value)
|
|
||||||
|
wat := ""
|
||||||
|
if ret.Value != nil {
|
||||||
|
valueWAT, err := c.compileExpressionWAT(*ret.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
wat += valueWAT
|
||||||
|
}
|
||||||
|
|
||||||
return wat + "return\n", nil
|
return wat + "return\n", nil
|
||||||
case Statement_DeclareLocalVariable:
|
case Statement_DeclareLocalVariable:
|
||||||
dlv := stmt.Value.(DeclareLocalVariableStatement)
|
dlv := stmt.Value.(DeclareLocalVariableStatement)
|
||||||
|
@ -47,6 +47,7 @@ type Statement struct {
|
|||||||
Type StatementType
|
Type StatementType
|
||||||
Value any
|
Value any
|
||||||
Position TokenPosition
|
Position TokenPosition
|
||||||
|
Returns bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type ExpressionStatement struct {
|
type ExpressionStatement struct {
|
||||||
@ -190,6 +191,7 @@ type Block struct {
|
|||||||
Parent *Block
|
Parent *Block
|
||||||
Statements []Statement
|
Statements []Statement
|
||||||
Locals map[string]Local
|
Locals map[string]Local
|
||||||
|
Returns bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type ParsedParameter struct {
|
type ParsedParameter struct {
|
||||||
|
20
types.go
20
types.go
@ -1,8 +1,6 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"slices"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -93,15 +91,6 @@ func getBits(primitiveType PrimitiveType) int {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPrimitiveTypeByName(name string) (PrimitiveType, error) {
|
|
||||||
idx := slices.Index(PRIMITIVE_TYPE_NAMES, name)
|
|
||||||
if idx == -1 {
|
|
||||||
return InvalidValue, errors.New("not a primitive type name")
|
|
||||||
}
|
|
||||||
|
|
||||||
return PrimitiveType(idx), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getPrimitiveTypeName(primitiveType PrimitiveType) string {
|
func getPrimitiveTypeName(primitiveType PrimitiveType) string {
|
||||||
return PRIMITIVE_TYPE_NAMES[primitiveType]
|
return PRIMITIVE_TYPE_NAMES[primitiveType]
|
||||||
}
|
}
|
||||||
@ -139,15 +128,6 @@ func (t PrimitiveType) String() string {
|
|||||||
return getPrimitiveTypeName(t)
|
return getPrimitiveTypeName(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
func isAssigmentOperator(operator Operator) bool {
|
|
||||||
switch operator {
|
|
||||||
case Operator_Equals, Operator_PlusEquals, Operator_MinusEquals, Operator_MultiplyEquals, Operator_DivideEquals, Operator_ModuloEquals:
|
|
||||||
return true
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getOperation(operator Operator) Operation {
|
func getOperation(operator Operator) Operation {
|
||||||
switch operator {
|
switch operator {
|
||||||
case Operator_Greater:
|
case Operator_Greater:
|
||||||
|
13
validator.go
13
validator.go
@ -411,9 +411,12 @@ func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local)
|
|||||||
case Statement_Block:
|
case Statement_Block:
|
||||||
block := stmt.Value.(BlockStatement)
|
block := stmt.Value.(BlockStatement)
|
||||||
errors = append(errors, v.validateBlock(block.Block, functionLocals)...)
|
errors = append(errors, v.validateBlock(block.Block, functionLocals)...)
|
||||||
|
stmt.Returns = block.Block.Returns
|
||||||
stmt.Value = block
|
stmt.Value = block
|
||||||
case Statement_Return:
|
case Statement_Return:
|
||||||
ret := stmt.Value.(ReturnStatement)
|
ret := stmt.Value.(ReturnStatement)
|
||||||
|
stmt.Returns = true
|
||||||
|
|
||||||
if ret.Value != nil {
|
if ret.Value != nil {
|
||||||
if v.CurrentFunction.ReturnType == nil {
|
if v.CurrentFunction.ReturnType == nil {
|
||||||
errors = append(errors, v.createError("cannot return value from void function", stmt.Position))
|
errors = append(errors, v.createError("cannot return value from void function", stmt.Position))
|
||||||
@ -470,6 +473,8 @@ func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local)
|
|||||||
errors = append(errors, v.validateBlock(ifS.ElseBlock, functionLocals)...)
|
errors = append(errors, v.validateBlock(ifS.ElseBlock, functionLocals)...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
stmt.Returns = ifS.ConditionalBlock.Returns && ifS.ElseBlock != nil && ifS.ElseBlock.Returns
|
||||||
|
|
||||||
if len(errors) != 0 {
|
if len(errors) != 0 {
|
||||||
return errors
|
return errors
|
||||||
}
|
}
|
||||||
@ -512,6 +517,10 @@ func (v *Validator) validateBlock(block *Block, functionLocals *[]Local) []error
|
|||||||
v.CurrentBlock = block
|
v.CurrentBlock = block
|
||||||
stmt := &block.Statements[i]
|
stmt := &block.Statements[i]
|
||||||
errors = append(errors, v.validateStatement(stmt, functionLocals)...)
|
errors = append(errors, v.validateStatement(stmt, functionLocals)...)
|
||||||
|
|
||||||
|
if stmt.Returns {
|
||||||
|
block.Returns = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return errors
|
return errors
|
||||||
@ -534,7 +543,9 @@ func (v *Validator) validateFunction(function *ParsedFunction) []error {
|
|||||||
|
|
||||||
errors = append(errors, v.validateBlock(body, &locals)...)
|
errors = append(errors, v.validateBlock(body, &locals)...)
|
||||||
|
|
||||||
// TODO: validate that function returns return value
|
if function.ReturnType != nil && !body.Returns {
|
||||||
|
errors = append(errors, v.createError("function must return a value", function.ReturnType.Position))
|
||||||
|
}
|
||||||
|
|
||||||
function.Locals = locals
|
function.Locals = locals
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user