diff --git a/backend_wat.go b/backend_wat.go index 307518c..48ecc29 100644 --- a/backend_wat.go +++ b/backend_wat.go @@ -529,9 +529,15 @@ func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, er return wat, nil case Statement_Return: ret := stmt.Value.(ReturnStatement) - wat, err := c.compileExpressionWAT(*ret.Value) - if err != nil { - return "", err + + wat := "" + if ret.Value != nil { + valueWAT, err := c.compileExpressionWAT(*ret.Value) + if err != nil { + return "", err + } + + wat += valueWAT } return wat + "return\n", nil diff --git a/parser.go b/parser.go index cad9b12..ab54a58 100644 --- a/parser.go +++ b/parser.go @@ -47,6 +47,7 @@ type Statement struct { Type StatementType Value any Position TokenPosition + Returns bool } type ExpressionStatement struct { @@ -190,6 +191,7 @@ type Block struct { Parent *Block Statements []Statement Locals map[string]Local + Returns bool } type ParsedParameter struct { diff --git a/types.go b/types.go index a38bf12..9205c09 100644 --- a/types.go +++ b/types.go @@ -1,8 +1,6 @@ package main import ( - "errors" - "slices" "strconv" ) @@ -93,15 +91,6 @@ func getBits(primitiveType PrimitiveType) int { } } -func getPrimitiveTypeByName(name string) (PrimitiveType, error) { - idx := slices.Index(PRIMITIVE_TYPE_NAMES, name) - if idx == -1 { - return InvalidValue, errors.New("not a primitive type name") - } - - return PrimitiveType(idx), nil -} - func getPrimitiveTypeName(primitiveType PrimitiveType) string { return PRIMITIVE_TYPE_NAMES[primitiveType] } @@ -139,15 +128,6 @@ func (t PrimitiveType) String() string { return getPrimitiveTypeName(t) } -func isAssigmentOperator(operator Operator) bool { - switch operator { - case Operator_Equals, Operator_PlusEquals, Operator_MinusEquals, Operator_MultiplyEquals, Operator_DivideEquals, Operator_ModuloEquals: - return true - default: - return false - } -} - func getOperation(operator Operator) Operation { switch operator { case Operator_Greater: diff --git a/validator.go b/validator.go index 1c6ac4d..960efb8 100644 --- a/validator.go +++ b/validator.go @@ -411,9 +411,12 @@ func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local) case Statement_Block: block := stmt.Value.(BlockStatement) errors = append(errors, v.validateBlock(block.Block, functionLocals)...) + stmt.Returns = block.Block.Returns stmt.Value = block case Statement_Return: ret := stmt.Value.(ReturnStatement) + stmt.Returns = true + if ret.Value != nil { if v.CurrentFunction.ReturnType == nil { errors = append(errors, v.createError("cannot return value from void function", stmt.Position)) @@ -470,6 +473,8 @@ func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local) errors = append(errors, v.validateBlock(ifS.ElseBlock, functionLocals)...) } + stmt.Returns = ifS.ConditionalBlock.Returns && ifS.ElseBlock != nil && ifS.ElseBlock.Returns + if len(errors) != 0 { return errors } @@ -512,6 +517,10 @@ func (v *Validator) validateBlock(block *Block, functionLocals *[]Local) []error v.CurrentBlock = block stmt := &block.Statements[i] errors = append(errors, v.validateStatement(stmt, functionLocals)...) + + if stmt.Returns { + block.Returns = true + } } return errors @@ -534,7 +543,9 @@ func (v *Validator) validateFunction(function *ParsedFunction) []error { errors = append(errors, v.validateBlock(body, &locals)...) - // TODO: validate that function returns return value + if function.ReturnType != nil && !body.Returns { + errors = append(errors, v.createError("function must return a value", function.ReturnType.Position)) + } function.Locals = locals