package main import ( "errors" ) func createError(message string) error { // TODO: pass token and get actual token position return errors.New(message) } func validateImport(imp *Import) []error { // TODO return nil } func isTypeExpandableTo(from PrimitiveType, to PrimitiveType) bool { if from == to { return true } switch from { case Primitive_I8: case Primitive_U8: if to == Primitive_I16 || to == Primitive_U16 { return true } fallthrough case Primitive_I16: case Primitive_U16: if to == Primitive_I32 || to == Primitive_U32 { return true } fallthrough case Primitive_I32: case Primitive_U32: if to == Primitive_I64 || to == Primitive_U64 { return true } case Primitive_F32: if to == Primitive_F64 { return true } } return false } func getArithmeticResultType(left PrimitiveType, right PrimitiveType, operation ArithmeticOperation) (PrimitiveType, error) { if left == Primitive_Bool || right == Primitive_Bool { return InvalidValue, createError("bool type cannot be used in arithmetic expressions") } if isTypeExpandableTo(left, right) { return right, nil } if isTypeExpandableTo(right, left) { return left, nil } // TODO: boolean expressions etc. return InvalidValue, createError("cannot use these types in an arithmetic expression without an explicit cast") // TODO: include type names in error } func validateExpression(expr *Expression, block *Block) []error { var errors []error switch expr.Type { case Expression_Assignment: assignment := expr.Value.(AssignmentExpression) var local Local var ok bool if local, ok = block.Locals[assignment.Variable]; !ok { errors = append(errors, createError("Assignment to undeclared variable "+assignment.Variable)) return errors } valueErrors := validateExpression(&assignment.Value, block) if len(valueErrors) != 0 { errors = append(errors, valueErrors...) return errors } // TODO: check if assignment is valid expr.ValueType = local.Type expr.Value = assignment case Expression_Literal: literal := expr.Value.(LiteralExpression) switch literal.Literal.Type { case Literal_Boolean: case Literal_Number: expr.ValueType = Type{Type: Type_Primitive, Value: literal.Literal.Primitive} case Literal_String: expr.ValueType = STRING_TYPE } case Expression_VariableReference: reference := expr.Value.(VariableReferenceExpression) var local Local var ok bool if local, ok = block.Locals[reference.Variable]; !ok { errors = append(errors, createError("Reference to undeclared variable "+reference.Variable)) return errors } expr.ValueType = local.Type expr.Value = reference case Expression_Arithmetic: arithmethic := expr.Value.(ArithmeticExpression) errors = append(errors, validateExpression(&arithmethic.Left, block)...) errors = append(errors, validateExpression(&arithmethic.Right, block)...) if len(errors) != 0 { return errors } // TODO: validate types compatible and determine result type if arithmethic.Left.ValueType.Type != Type_Primitive || arithmethic.Right.ValueType.Type != Type_Primitive { errors = append(errors, createError("both sides of an arithmetic expression must be a primitive type")) return errors } leftType := arithmethic.Left.ValueType.Value.(PrimitiveType) rightType := arithmethic.Left.ValueType.Value.(PrimitiveType) result, err := getArithmeticResultType(leftType, rightType, arithmethic.Operation) if err != nil { errors = append(errors, err) return errors } expr.ValueType = Type{Type: Type_Primitive, Value: result} expr.Value = arithmethic case Expression_Tuple: tuple := expr.Value.(TupleExpression) var types []Type for i := range tuple.Members { member := &tuple.Members[i] memberErrors := validateExpression(member, block) if len(memberErrors) != 0 { errors = append(errors, memberErrors...) continue } types = append(types, member.ValueType) } if len(errors) != 0 { return errors } expr.ValueType = Type{Type: Type_Tuple, Value: TupleType{Types: types}} expr.Value = tuple } return errors } func validateStatement(stmt *Statement, block *Block, functionLocals *[]Local) []error { var errors []error // TODO: support references to variables in parent block switch stmt.Type { case Statement_Expression: expression := stmt.Value.(ExpressionStatement) errors = append(errors, validateExpression(&expression.Expression, block)...) *stmt = Statement{Type: Statement_Expression, Value: expression} case Statement_Block: block := stmt.Value.(BlockStatement) errors = append(errors, validateBlock(&block.Block, functionLocals)...) *stmt = Statement{Type: Statement_Block, Value: block} case Statement_Return: ret := stmt.Value.(ReturnStatement) if ret.Value != nil { errors = append(errors, validateExpression(ret.Value, block)...) } case Statement_DeclareLocalVariable: dlv := stmt.Value.(DeclareLocalVariableStatement) if dlv.Initializer != nil { errors = append(errors, validateExpression(dlv.Initializer, block)...) } if _, ok := block.Locals[dlv.Variable]; ok { errors = append(errors, createError("redeclaration of variable '"+dlv.Variable+"'")) } local := Local{Name: dlv.Variable, Type: dlv.VariableType, IsParameter: false, Index: len(*functionLocals)} block.Locals[dlv.Variable] = local *functionLocals = append(*functionLocals, local) } return errors } func validateBlock(block *Block, functionLocals *[]Local) []error { var errors []error if block.Locals == nil { block.Locals = make(map[string]Local) } for i := range block.Statements { stmt := &block.Statements[i] errors = append(errors, validateStatement(stmt, block, functionLocals)...) } return errors } func validateFunction(function *ParsedFunction) []error { var errors []error var locals []Local body := &function.Body body.Locals = make(map[string]Local) for _, param := range function.Parameters { local := Local{Name: param.Name, Type: param.Type, IsParameter: true, Index: len(locals)} locals = append(locals, local) body.Locals[param.Name] = local } errors = append(errors, validateBlock(body, &locals)...) function.Locals = locals return errors } func validator(file *ParsedFile) []error { var errors []error for i := range file.Imports { errors = append(errors, validateImport(&file.Imports[i])...) } for i := range file.Functions { errors = append(errors, validateFunction(&file.Functions[i])...) } return errors }