package main import ( "strconv" ) type Validator struct { file *ParsedFile } func isTypeExpandableTo(from Type, to Type) bool { if from.Type == Type_Primitive && to.Type == Type_Primitive { return isPrimitiveTypeExpandableTo(from.Value.(PrimitiveType), to.Value.(PrimitiveType)) } panic("not implemented") } func isPrimitiveTypeExpandableTo(from PrimitiveType, to PrimitiveType) bool { if from == to { return true } switch from { case Primitive_I8, Primitive_U8: if to == Primitive_I16 || to == Primitive_U16 { return true } fallthrough case Primitive_I16, Primitive_U16: if to == Primitive_I32 || to == Primitive_U32 { return true } fallthrough case Primitive_I32, Primitive_U32: if to == Primitive_I64 || to == Primitive_U64 { return true } case Primitive_F32: if to == Primitive_F64 { return true } } return false } func (v *Validator) createError(message string, position uint64) error { // TODO: pass token and get actual token position return CompilerError{Position: position, Message: message} } func (v *Validator) validateImport(imp *Import) []error { // TODO return nil } func (v *Validator) getArithmeticResultType(expr *Expression, left PrimitiveType, right PrimitiveType, operation ArithmeticOperation) (PrimitiveType, error) { if left == Primitive_Bool || right == Primitive_Bool { return InvalidValue, v.createError("bool type cannot be used in arithmetic expressions", expr.Position) } if isPrimitiveTypeExpandableTo(left, right) { return right, nil } if isPrimitiveTypeExpandableTo(right, left) { return left, nil } // TODO: boolean expressions etc. return InvalidValue, v.createError("cannot use these types in an arithmetic expression without an explicit cast", expr.Position) // TODO: include type names in error } func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *Block) []error { var errors []error switch expr.Type { case Expression_Assignment: assignment := expr.Value.(AssignmentExpression) var local Local var ok bool if local, ok = block.Locals[assignment.Variable]; !ok { errors = append(errors, v.createError("Assignment to undeclared variable "+assignment.Variable, expr.Position)) return errors } valueErrors := v.validateExpression(&assignment.Value, block) if len(valueErrors) != 0 { errors = append(errors, valueErrors...) return errors } // TODO: check if assignment is valid expr.ValueType = &local.Type expr.Value = assignment case Expression_Literal: literal := expr.Value.(LiteralExpression) switch literal.Literal.Type { case Literal_Boolean, Literal_Number: expr.ValueType = &Type{Type: Type_Primitive, Value: literal.Literal.Primitive} case Literal_String: expr.ValueType = &STRING_TYPE } case Expression_VariableReference: reference := expr.Value.(VariableReferenceExpression) var local Local var ok bool if local, ok = block.Locals[reference.Variable]; !ok { errors = append(errors, v.createError("Reference to undeclared variable "+reference.Variable, expr.Position)) return errors } expr.ValueType = &local.Type case Expression_Binary: arithmethic := expr.Value.(BinaryExpression) errors = append(errors, v.validateExpression(&arithmethic.Left, block)...) errors = append(errors, v.validateExpression(&arithmethic.Right, block)...) if len(errors) != 0 { return errors } if arithmethic.Left.ValueType.Type != Type_Primitive || arithmethic.Right.ValueType.Type != Type_Primitive { errors = append(errors, v.createError("both sides of an arithmetic expression must be a primitive type", expr.Position)) return errors } leftType := arithmethic.Left.ValueType.Value.(PrimitiveType) rightType := arithmethic.Right.ValueType.Value.(PrimitiveType) result, err := v.getArithmeticResultType(expr, leftType, rightType, arithmethic.Operation) if err != nil { errors = append(errors, err) return errors } expr.ValueType = &Type{Type: Type_Primitive, Value: result} expr.Value = arithmethic case Expression_Tuple: tuple := expr.Value.(TupleExpression) var types []Type for i := range tuple.Members { member := &tuple.Members[i] memberErrors := v.validateExpression(member, block) if len(memberErrors) != 0 { errors = append(errors, memberErrors...) continue } types = append(types, *member.ValueType) } if len(errors) != 0 { return errors } expr.ValueType = &Type{Type: Type_Tuple, Value: TupleType{Types: types}} expr.Value = tuple case Expression_FunctionCall: fc := expr.Value.(FunctionCallExpression) var calledFunc *ParsedFunction = nil for _, f := range v.file.Functions { if f.Name == fc.Function { calledFunc = &f break } } if calledFunc == nil { errors = append(errors, v.createError("call to undefined function '"+fc.Function+"'", expr.Position)) return errors } if fc.Parameters != nil { paramsErrors := v.validateExpression(fc.Parameters, block) if len(paramsErrors) != 0 { errors = append(errors, paramsErrors...) return errors } params := []Expression{*fc.Parameters} if fc.Parameters.Type == Expression_Tuple { params = fc.Parameters.Value.(TupleExpression).Members } if len(params) != len(calledFunc.Parameters) { errors = append(errors, v.createError("wrong number of arguments in function call: expected "+strconv.Itoa(len(calledFunc.Parameters))+", got "+strconv.Itoa(len(params)), expr.Position)) } for i := 0; i < min(len(params), len(calledFunc.Parameters)); i++ { typeGiven := params[i] typeExpected := calledFunc.Parameters[i] if !isTypeExpandableTo(*typeGiven.ValueType, typeExpected.Type) { errors = append(errors, v.createError("invalid type for parameter "+strconv.Itoa(i), expr.Position)) } } } // TODO: get function and validate using return type expr.ValueType = calledFunc.ReturnType expr.Value = fc case Expression_Negate: neg := expr.Value.(NegateExpression) valErrors := v.validateExpression(&neg.Value, block) if len(valErrors) != 0 { errors = append(errors, valErrors...) return errors } if neg.Value.ValueType.Type != Type_Primitive { errors = append(errors, v.createError("cannot negate non-number types", expr.Position)) } expr.ValueType = neg.Value.ValueType expr.Value = neg } return errors } func (v *Validator) validateExpression(expr *Expression, block *Block) []error { errors := v.validatePotentiallyVoidExpression(expr, block) if len(errors) != 0 { return errors } if expr.ValueType == nil { errors = append(errors, v.createError("expression must not evaluate to void", expr.Position)) } return errors } func (v *Validator) validateStatement(stmt *Statement, block *Block, functionLocals *[]Local) []error { var errors []error // TODO: support references to variables in parent block switch stmt.Type { case Statement_Expression: expression := stmt.Value.(ExpressionStatement) errors = append(errors, v.validatePotentiallyVoidExpression(&expression.Expression, block)...) stmt.Value = expression case Statement_Block: block := stmt.Value.(BlockStatement) errors = append(errors, v.validateBlock(&block.Block, functionLocals)...) stmt.Value = block case Statement_Return: ret := stmt.Value.(ReturnStatement) if ret.Value != nil { errors = append(errors, v.validateExpression(ret.Value, block)...) } case Statement_DeclareLocalVariable: dlv := stmt.Value.(DeclareLocalVariableStatement) if dlv.Initializer != nil { errors = append(errors, v.validateExpression(dlv.Initializer, block)...) } if _, ok := block.Locals[dlv.Variable]; ok { errors = append(errors, v.createError("redeclaration of variable '"+dlv.Variable+"'", stmt.Position)) } local := Local{Name: dlv.Variable, Type: dlv.VariableType, IsParameter: false, Index: len(*functionLocals)} block.Locals[dlv.Variable] = local *functionLocals = append(*functionLocals, local) // TODO: check if assignment of initializer is correct } return errors } func (v *Validator) validateBlock(block *Block, functionLocals *[]Local) []error { var errors []error if block.Locals == nil { block.Locals = make(map[string]Local) } for i := range block.Statements { stmt := &block.Statements[i] errors = append(errors, v.validateStatement(stmt, block, functionLocals)...) } return errors } func (v *Validator) validateFunction(function *ParsedFunction) []error { var errors []error var locals []Local body := &function.Body body.Locals = make(map[string]Local) for _, param := range function.Parameters { local := Local{Name: param.Name, Type: param.Type, IsParameter: true, Index: len(locals)} locals = append(locals, local) body.Locals[param.Name] = local } errors = append(errors, v.validateBlock(body, &locals)...) function.Locals = locals return errors } func (v *Validator) validate() []error { var errors []error for i := range v.file.Imports { errors = append(errors, v.validateImport(&v.file.Imports[i])...) } for i := range v.file.Functions { errors = append(errors, v.validateFunction(&v.file.Functions[i])...) } return errors }