diff --git a/backend_wat.go b/backend_wat.go index 41a3878..9b2e205 100644 --- a/backend_wat.go +++ b/backend_wat.go @@ -333,13 +333,12 @@ func compileFunctionWAT(function ParsedFunction) (string, error) { } // TODO: tuples - returnTypes := []Type{function.ReturnType} - if function.ReturnType.Type == Type_Tuple { - returnTypes = function.ReturnType.Value.(TupleType).Types - } - - if function.ReturnType.Type == Type_Void { - returnTypes = []Type{} + returnTypes := []Type{} + if function.ReturnType != nil { + returnTypes = []Type{*function.ReturnType} + if function.ReturnType.Type == Type_Tuple { + returnTypes = function.ReturnType.Value.(TupleType).Types + } } for _, t := range returnTypes { diff --git a/lexer.go b/lexer.go index 65cc5c2..4d917ea 100644 --- a/lexer.go +++ b/lexer.go @@ -1,7 +1,6 @@ package main import ( - "errors" "slices" "strconv" "strings" @@ -82,12 +81,13 @@ type Literal struct { } type Lexer struct { - Runes []rune - Position uint64 + Runes []rune + LastTokenPosition uint64 + Position uint64 } func (l *Lexer) error(message string) error { - return CompilerError{Position: l.Position, Message: message} + return CompilerError{Position: l.LastTokenPosition, Message: message} } func (l *Lexer) peekRune() *rune { @@ -110,6 +110,8 @@ func (l *Lexer) nextRune() *rune { } func (l *Lexer) stringLiteral() (string, error) { + l.LastTokenPosition = l.Position + openQuote := l.nextRune() if openQuote == nil || *openQuote != '"' { return "", l.error("expected \"") @@ -158,6 +160,8 @@ func (l *Lexer) nextToken() (string, error) { l.nextRune() } + l.LastTokenPosition = l.Position + r := l.peekRune() if r == nil { return "", nil @@ -208,7 +212,7 @@ func parseNumber(raw string, numberType PrimitiveType) (any, error) { func (l *Lexer) parseToken(token string) (*LexToken, error) { if strings.HasPrefix(token, "\"") { - return &LexToken{Type: Type_Literal, Position: l.Position, Value: Literal{Type: Literal_String, Primitive: InvalidValue, Value: token[1 : len(token)-1]}}, nil + return &LexToken{Type: Type_Literal, Position: l.LastTokenPosition, Value: Literal{Type: Literal_String, Primitive: InvalidValue, Value: token[1 : len(token)-1]}}, nil } runes := []rune(token) @@ -235,7 +239,7 @@ func (l *Lexer) parseToken(token string) (*LexToken, error) { } if containsDot && !isFloatingPoint(numberType) { - return nil, errors.New("dot in non floating-point constant") + return nil, l.error("dot in non floating-point constant") } number, err := parseNumber(rawNumber, numberType) @@ -243,24 +247,24 @@ func (l *Lexer) parseToken(token string) (*LexToken, error) { return nil, err } - return &LexToken{Type: Type_Literal, Position: l.Position, Value: Literal{Type: Literal_Number, Primitive: numberType, Value: number}}, nil + return &LexToken{Type: Type_Literal, Position: l.LastTokenPosition, Value: Literal{Type: Literal_Number, Primitive: numberType, Value: number}}, nil } if len(runes) == 1 { if idx := slices.Index(Separators, runes[0]); idx != -1 { - return &LexToken{Type: Type_Separator, Position: l.Position, Value: Separator(idx)}, nil + return &LexToken{Type: Type_Separator, Position: l.LastTokenPosition, Value: Separator(idx)}, nil } if idx := slices.Index(Operators, runes[0]); idx != -1 { - return &LexToken{Type: Type_Operator, Position: l.Position, Value: Operator(idx)}, nil + return &LexToken{Type: Type_Operator, Position: l.LastTokenPosition, Value: Operator(idx)}, nil } } if idx := slices.Index(Keywords, token); idx != -1 { - return &LexToken{Type: Type_Keyword, Position: l.Position, Value: Keyword(idx)}, nil + return &LexToken{Type: Type_Keyword, Position: l.LastTokenPosition, Value: Keyword(idx)}, nil } - return &LexToken{Type: Type_Identifier, Position: l.Position, Value: token}, nil + return &LexToken{Type: Type_Identifier, Position: l.LastTokenPosition, Value: token}, nil } func lexer(program string) ([]LexToken, error) { diff --git a/parser.go b/parser.go index 03631c6..6c44d99 100644 --- a/parser.go +++ b/parser.go @@ -9,15 +9,15 @@ type TypeType uint32 const ( Type_Primitive TypeType = iota - Type_Void Type_Named Type_Array Type_Tuple ) type Type struct { - Type TypeType - Value any + Type TypeType + Value any + Position uint64 } type NamedType struct { @@ -42,8 +42,9 @@ const ( ) type Statement struct { - Type StatementType - Value any + Type StatementType + Value any + Position uint64 } type ExpressionStatement struct { @@ -79,7 +80,8 @@ const ( type Expression struct { Type ExpressionType Value any - ValueType Type + ValueType *Type + Position uint64 } type AssignmentExpression struct { @@ -145,7 +147,7 @@ type ParsedParameter struct { type ParsedFunction struct { Name string Parameters []ParsedParameter - ReturnType Type + ReturnType *Type Body Block Locals []Local // All of the locals of the function, ordered by their index } @@ -273,22 +275,17 @@ func (p *Parser) tryType() (*Type, error) { return nil, nil } - if tok.Type == Type_Keyword && tok.Value.(Keyword) == Keyword_Void { - *p = pCopy - return &Type{Type: Type_Void, Value: nil}, nil - } - if tok.Type == Type_Identifier { // TODO: array type index := slices.Index(PRIMITIVE_TYPE_NAMES, tok.Value.(string)) if index != -1 { *p = pCopy - return &Type{Type: Type_Primitive, Value: PrimitiveType(index)}, nil + return &Type{Type: Type_Primitive, Value: PrimitiveType(index), Position: tok.Position}, nil } *p = pCopy - return &Type{Type: Type_Named, Value: tok.Value}, nil + return &Type{Type: Type_Named, Value: tok.Value, Position: tok.Position}, nil } return nil, nil @@ -342,7 +339,7 @@ func (p *Parser) expectTypeOrTupleType() (*Type, error) { return nil, p.error("empty tuple") } - return &Type{Type: Type_Tuple, Value: TupleType{Types: types}}, nil + return &Type{Type: Type_Tuple, Value: TupleType{Types: types}, Position: tok.Position}, nil } t, err := p.tryType() @@ -412,7 +409,7 @@ func (p *Parser) tryUnaryExpression() (*Expression, error) { if token.Type == Type_Literal { pCopy.nextToken() *p = pCopy - return &Expression{Type: Expression_Literal, Value: LiteralExpression{Literal: token.Value.(Literal)}}, nil + return &Expression{Type: Expression_Literal, Value: LiteralExpression{Literal: token.Value.(Literal)}, Position: token.Position}, nil } if token.Type == Type_Keyword { @@ -420,7 +417,7 @@ func (p *Parser) tryUnaryExpression() (*Expression, error) { if keyword == Keyword_True || keyword == KeyWord_False { pCopy.nextToken() *p = pCopy - return &Expression{Type: Expression_Literal, Value: LiteralExpression{Literal: Literal{Type: Literal_Boolean, Primitive: Primitive_Bool, Value: keyword == Keyword_True}}}, nil + return &Expression{Type: Expression_Literal, Value: LiteralExpression{Literal: Literal{Type: Literal_Boolean, Primitive: Primitive_Bool, Value: keyword == Keyword_True}}, Position: token.Position}, nil } } @@ -438,7 +435,7 @@ func (p *Parser) tryUnaryExpression() (*Expression, error) { } if op == Operator_Minus { - expr = &Expression{Type: Expression_Negate, Value: NegateExpression{Value: *expr}} + expr = &Expression{Type: Expression_Negate, Value: NegateExpression{Value: *expr}, Position: token.Position} } return expr, nil @@ -466,11 +463,11 @@ func (p *Parser) tryUnaryExpression() (*Expression, error) { } *p = pCopy - return &Expression{Type: Expression_FunctionCall, Value: FunctionCallExpression{Function: token.Value.(string), Parameters: params}}, nil + return &Expression{Type: Expression_FunctionCall, Value: FunctionCallExpression{Function: token.Value.(string), Parameters: params}, Position: token.Position}, nil } *p = pCopy - return &Expression{Type: Expression_VariableReference, Value: VariableReferenceExpression{Variable: token.Value.(string)}}, nil + return &Expression{Type: Expression_VariableReference, Value: VariableReferenceExpression{Variable: token.Value.(string)}, Position: token.Position}, nil } return nil, nil @@ -521,7 +518,7 @@ func (p *Parser) tryMultiplicativeExpression() (*Expression, error) { operation = Arithmetic_Mod } - left = &Expression{Type: Expression_Arithmetic, Value: ArithmeticExpression{Operation: operation, Left: *left, Right: *right}} + left = &Expression{Type: Expression_Arithmetic, Value: ArithmeticExpression{Operation: operation, Left: *left, Right: *right}, Position: left.Position} } } @@ -561,7 +558,7 @@ func (p *Parser) tryAdditiveExpression() (*Expression, error) { operation = Arithmetic_Sub } - left = &Expression{Type: Expression_Arithmetic, Value: ArithmeticExpression{Operation: operation, Left: *left, Right: *right}} + left = &Expression{Type: Expression_Arithmetic, Value: ArithmeticExpression{Operation: operation, Left: *left, Right: *right}, Position: left.Position} } } @@ -603,7 +600,7 @@ func (p *Parser) tryTupleExpression() (*Expression, error) { } *p = pCopy - return &Expression{Type: Expression_Tuple, Value: TupleExpression{Members: members}}, nil + return &Expression{Type: Expression_Tuple, Value: TupleExpression{Members: members}, Position: members[0].Position}, nil } func (p *Parser) expectTupleExpression() (*Expression, error) { @@ -646,7 +643,7 @@ func (p *Parser) tryDeclareLocalVariableStatement() (*Statement, error) { token := pCopy.nextToken() if token.Type == Type_Separator && token.Value.(Separator) == Separator_Semicolon { *p = pCopy - return &Statement{Type: Statement_DeclareLocalVariable, Value: DeclareLocalVariableStatement{Variable: variableName, VariableType: *variableType, Initializer: nil}}, nil + return &Statement{Type: Statement_DeclareLocalVariable, Value: DeclareLocalVariableStatement{Variable: variableName, VariableType: *variableType, Initializer: nil}, Position: variableType.Position}, nil } if token.Type != Type_Operator || token.Value.(Operator) != Operator_Equals { @@ -666,7 +663,7 @@ func (p *Parser) tryDeclareLocalVariableStatement() (*Statement, error) { } *p = pCopy - return &Statement{Type: Statement_DeclareLocalVariable, Value: DeclareLocalVariableStatement{Variable: variableName, VariableType: *variableType, Initializer: initializer}}, nil + return &Statement{Type: Statement_DeclareLocalVariable, Value: DeclareLocalVariableStatement{Variable: variableName, VariableType: *variableType, Initializer: initializer}, Position: variableType.Position}, nil } func (p *Parser) expectStatement() (*Statement, error) { @@ -685,7 +682,7 @@ func (p *Parser) expectStatement() (*Statement, error) { if token.Type == Type_Separator && token.Value.(Separator) == Separator_Semicolon { p.nextToken() - return &Statement{Type: Statement_Return, Value: ReturnStatement{Value: nil}}, nil + return &Statement{Type: Statement_Return, Value: ReturnStatement{Value: nil}, Position: token.Position}, nil } expr, err := p.expectTupleExpression() @@ -698,7 +695,7 @@ func (p *Parser) expectStatement() (*Statement, error) { return nil, err } - return &Statement{Type: Statement_Return, Value: ReturnStatement{Value: expr}}, nil + return &Statement{Type: Statement_Return, Value: ReturnStatement{Value: expr}, Position: token.Position}, nil } if token.Type == Type_Separator && token.Value.(Separator) == Separator_OpenCurly { @@ -707,7 +704,7 @@ func (p *Parser) expectStatement() (*Statement, error) { return nil, err } - return &Statement{Type: Statement_Block, Value: BlockStatement{Block: *block}}, nil + return &Statement{Type: Statement_Block, Value: BlockStatement{Block: *block}, Position: token.Position}, nil } stmt, err := p.tryDeclareLocalVariableStatement() @@ -730,7 +727,7 @@ func (p *Parser) expectStatement() (*Statement, error) { return nil, err } - return &Statement{Type: Statement_Expression, Value: ExpressionStatement{Expression: *expr}}, nil + return &Statement{Type: Statement_Expression, Value: ExpressionStatement{Expression: *expr}, Position: expr.Position}, nil } return nil, p.error("expected statement") @@ -773,9 +770,15 @@ func (p *Parser) expectFunction() (*ParsedFunction, error) { var returnType *Type var body *Block - returnType, err = p.expectTypeOrTupleType() - if err != nil { - return nil, err + tok := p.peekToken() + if tok.Type == Type_Keyword && tok.Value.(Keyword) == Keyword_Void { + p.nextToken() + returnType = nil + } else { + returnType, err = p.expectTypeOrTupleType() + if err != nil { + return nil, err + } } name, err = p.expectIdentifier() @@ -826,7 +829,7 @@ func (p *Parser) expectFunction() (*ParsedFunction, error) { return nil, err } - return &ParsedFunction{Name: name, Parameters: parameters, ReturnType: *returnType, Body: *body}, nil + return &ParsedFunction{Name: name, Parameters: parameters, ReturnType: returnType, Body: *body}, nil } func (p *Parser) parseFile() (*ParsedFile, error) { diff --git a/validator.go b/validator.go index c1cc611..d2119c3 100644 --- a/validator.go +++ b/validator.go @@ -1,7 +1,6 @@ package main import ( - "errors" "strconv" ) @@ -49,9 +48,9 @@ func isPrimitiveTypeExpandableTo(from PrimitiveType, to PrimitiveType) bool { return false } -func (v *Validator) createError(message string) error { +func (v *Validator) createError(message string, position uint64) error { // TODO: pass token and get actual token position - return errors.New(message) + return CompilerError{Position: position, Message: message} } func (v *Validator) validateImport(imp *Import) []error { @@ -59,9 +58,9 @@ func (v *Validator) validateImport(imp *Import) []error { return nil } -func (v *Validator) getArithmeticResultType(left PrimitiveType, right PrimitiveType, operation ArithmeticOperation) (PrimitiveType, error) { +func (v *Validator) getArithmeticResultType(expr *Expression, left PrimitiveType, right PrimitiveType, operation ArithmeticOperation) (PrimitiveType, error) { if left == Primitive_Bool || right == Primitive_Bool { - return InvalidValue, v.createError("bool type cannot be used in arithmetic expressions") + return InvalidValue, v.createError("bool type cannot be used in arithmetic expressions", expr.Position) } if isPrimitiveTypeExpandableTo(left, right) { @@ -74,10 +73,10 @@ func (v *Validator) getArithmeticResultType(left PrimitiveType, right PrimitiveT // TODO: boolean expressions etc. - return InvalidValue, v.createError("cannot use these types in an arithmetic expression without an explicit cast") // TODO: include type names in error + return InvalidValue, v.createError("cannot use these types in an arithmetic expression without an explicit cast", expr.Position) // TODO: include type names in error } -func (v *Validator) validateExpression(expr *Expression, block *Block) []error { +func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *Block) []error { var errors []error switch expr.Type { @@ -86,7 +85,7 @@ func (v *Validator) validateExpression(expr *Expression, block *Block) []error { var local Local var ok bool if local, ok = block.Locals[assignment.Variable]; !ok { - errors = append(errors, v.createError("Assignment to undeclared variable "+assignment.Variable)) + errors = append(errors, v.createError("Assignment to undeclared variable "+assignment.Variable, expr.Position)) return errors } @@ -97,27 +96,27 @@ func (v *Validator) validateExpression(expr *Expression, block *Block) []error { } // TODO: check if assignment is valid - expr.ValueType = local.Type + expr.ValueType = &local.Type expr.Value = assignment case Expression_Literal: literal := expr.Value.(LiteralExpression) switch literal.Literal.Type { case Literal_Boolean, Literal_Number: - expr.ValueType = Type{Type: Type_Primitive, Value: literal.Literal.Primitive} + expr.ValueType = &Type{Type: Type_Primitive, Value: literal.Literal.Primitive} case Literal_String: - expr.ValueType = STRING_TYPE + expr.ValueType = &STRING_TYPE } case Expression_VariableReference: reference := expr.Value.(VariableReferenceExpression) var local Local var ok bool if local, ok = block.Locals[reference.Variable]; !ok { - errors = append(errors, v.createError("Reference to undeclared variable "+reference.Variable)) + errors = append(errors, v.createError("Reference to undeclared variable "+reference.Variable, expr.Position)) return errors } - expr.ValueType = local.Type + expr.ValueType = &local.Type case Expression_Arithmetic: arithmethic := expr.Value.(ArithmeticExpression) @@ -129,19 +128,19 @@ func (v *Validator) validateExpression(expr *Expression, block *Block) []error { } if arithmethic.Left.ValueType.Type != Type_Primitive || arithmethic.Right.ValueType.Type != Type_Primitive { - errors = append(errors, v.createError("both sides of an arithmetic expression must be a primitive type")) + errors = append(errors, v.createError("both sides of an arithmetic expression must be a primitive type", expr.Position)) return errors } leftType := arithmethic.Left.ValueType.Value.(PrimitiveType) rightType := arithmethic.Right.ValueType.Value.(PrimitiveType) - result, err := v.getArithmeticResultType(leftType, rightType, arithmethic.Operation) + result, err := v.getArithmeticResultType(expr, leftType, rightType, arithmethic.Operation) if err != nil { errors = append(errors, err) return errors } - expr.ValueType = Type{Type: Type_Primitive, Value: result} + expr.ValueType = &Type{Type: Type_Primitive, Value: result} expr.Value = arithmethic case Expression_Tuple: tuple := expr.Value.(TupleExpression) @@ -156,14 +155,14 @@ func (v *Validator) validateExpression(expr *Expression, block *Block) []error { continue } - types = append(types, member.ValueType) + types = append(types, *member.ValueType) } if len(errors) != 0 { return errors } - expr.ValueType = Type{Type: Type_Tuple, Value: TupleType{Types: types}} + expr.ValueType = &Type{Type: Type_Tuple, Value: TupleType{Types: types}} expr.Value = tuple case Expression_FunctionCall: fc := expr.Value.(FunctionCallExpression) @@ -177,12 +176,16 @@ func (v *Validator) validateExpression(expr *Expression, block *Block) []error { } if calledFunc == nil { - errors = append(errors, v.createError("call to undefined function '"+fc.Function+"'")) + errors = append(errors, v.createError("call to undefined function '"+fc.Function+"'", expr.Position)) return errors } if fc.Parameters != nil { - errors = append(errors, v.validateExpression(fc.Parameters, block)...) + paramsErrors := v.validateExpression(fc.Parameters, block) + if len(paramsErrors) > 0 { + errors = append(errors, paramsErrors...) + return errors + } params := []Expression{*fc.Parameters} if fc.Parameters.Type == Expression_Tuple { @@ -190,14 +193,14 @@ func (v *Validator) validateExpression(expr *Expression, block *Block) []error { } if len(params) != len(calledFunc.Parameters) { - errors = append(errors, v.createError("wrong number of arguments in function call: expected "+strconv.Itoa(len(calledFunc.Parameters))+", got "+strconv.Itoa(len(params)))) + errors = append(errors, v.createError("wrong number of arguments in function call: expected "+strconv.Itoa(len(calledFunc.Parameters))+", got "+strconv.Itoa(len(params)), expr.Position)) } for i := 0; i < min(len(params), len(calledFunc.Parameters)); i++ { typeGiven := params[i] typeExpected := calledFunc.Parameters[i] - if !isTypeExpandableTo(typeGiven.ValueType, typeExpected.Type) { - errors = append(errors, v.createError("invalid type for parameter "+strconv.Itoa(i))) + if !isTypeExpandableTo(*typeGiven.ValueType, typeExpected.Type) { + errors = append(errors, v.createError("invalid type for parameter "+strconv.Itoa(i), expr.Position)) } } } @@ -208,10 +211,14 @@ func (v *Validator) validateExpression(expr *Expression, block *Block) []error { case Expression_Negate: neg := expr.Value.(NegateExpression) - errors = append(errors, v.validateExpression(&neg.Value, block)...) + valErrors := v.validateExpression(&neg.Value, block) + if len(valErrors) > 0 { + errors = append(errors, valErrors...) + return errors + } if neg.Value.ValueType.Type != Type_Primitive { - errors = append(errors, v.createError("cannot negate non-number types")) + errors = append(errors, v.createError("cannot negate non-number types", expr.Position)) } expr.ValueType = neg.Value.ValueType @@ -221,6 +228,16 @@ func (v *Validator) validateExpression(expr *Expression, block *Block) []error { return errors } +func (v *Validator) validateExpression(expr *Expression, block *Block) []error { + errors := v.validatePotentiallyVoidExpression(expr, block) + + if expr.ValueType == nil { + errors = append(errors, v.createError("expression must not evaluate to void", expr.Position)) + } + + return errors +} + func (v *Validator) validateStatement(stmt *Statement, block *Block, functionLocals *[]Local) []error { var errors []error @@ -229,12 +246,12 @@ func (v *Validator) validateStatement(stmt *Statement, block *Block, functionLoc switch stmt.Type { case Statement_Expression: expression := stmt.Value.(ExpressionStatement) - errors = append(errors, v.validateExpression(&expression.Expression, block)...) - *stmt = Statement{Type: Statement_Expression, Value: expression} + errors = append(errors, v.validatePotentiallyVoidExpression(&expression.Expression, block)...) + stmt.Value = expression case Statement_Block: block := stmt.Value.(BlockStatement) errors = append(errors, v.validateBlock(&block.Block, functionLocals)...) - *stmt = Statement{Type: Statement_Block, Value: block} + stmt.Value = block case Statement_Return: ret := stmt.Value.(ReturnStatement) if ret.Value != nil { @@ -247,7 +264,7 @@ func (v *Validator) validateStatement(stmt *Statement, block *Block, functionLoc } if _, ok := block.Locals[dlv.Variable]; ok { - errors = append(errors, v.createError("redeclaration of variable '"+dlv.Variable+"'")) + errors = append(errors, v.createError("redeclaration of variable '"+dlv.Variable+"'", stmt.Position)) } local := Local{Name: dlv.Variable, Type: dlv.VariableType, IsParameter: false, Index: len(*functionLocals)}