package main import ( "strconv" ) type Validator struct { file *ParsedFile currentBlock *Block currentFunction *ParsedFunction } func isTypeExpandableTo(from Type, to Type) bool { if from.Type != to.Type { // cannot convert between primitive, named, array and tuple types return false } if from.Type == Type_Primitive { return isPrimitiveTypeExpandableTo(from.Value.(PrimitiveType), to.Value.(PrimitiveType)) } if from.Type == Type_Tuple { fromT := from.Value.(TupleType) toT := to.Value.(TupleType) if len(fromT.Types) != len(toT.Types) { return false } for i := 0; i < len(fromT.Types); i++ { if !isTypeExpandableTo(fromT.Types[i], toT.Types[i]) { return false } } return true } panic("not implemented") } func isPrimitiveTypeExpandableTo(from PrimitiveType, to PrimitiveType) bool { if from == to { return true } switch from { case Primitive_I8, Primitive_U8: if to == Primitive_I16 || to == Primitive_U16 { return true } fallthrough case Primitive_I16, Primitive_U16: if to == Primitive_I32 || to == Primitive_U32 { return true } fallthrough case Primitive_I32, Primitive_U32: if to == Primitive_I64 || to == Primitive_U64 { return true } case Primitive_F32: if to == Primitive_F64 { return true } } return false } func (v *Validator) createError(message string, position uint64) error { // TODO: pass token and get actual token position return CompilerError{Position: position, Message: message} } func (v *Validator) validateImport(imp *Import) []error { // TODO return nil } func (v *Validator) getArithmeticResultType(expr *Expression, left PrimitiveType, right PrimitiveType, operation Operation) (PrimitiveType, error) { if left == Primitive_Bool || right == Primitive_Bool { return InvalidValue, v.createError("bool type cannot be used in arithmetic expressions", expr.Position) } if isPrimitiveTypeExpandableTo(left, right) { return right, nil } if isPrimitiveTypeExpandableTo(right, left) { return left, nil } // TODO: boolean expressions etc. return InvalidValue, v.createError("cannot use these types in an arithmetic expression without an explicit cast", expr.Position) // TODO: include type names in error } func getLocal(block *Block, variable string) *Local { if local, ok := block.Locals[variable]; ok { return &local } if block.Parent == nil { return nil } return getLocal(block.Parent, variable) } func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error { var errors []error switch expr.Type { case Expression_Assignment: assignment := expr.Value.(AssignmentExpression) local := getLocal(v.currentBlock, assignment.Variable) if local == nil { errors = append(errors, v.createError("Assignment to undeclared variable "+assignment.Variable, expr.Position)) return errors } valueErrors := v.validateExpression(&assignment.Value) if len(valueErrors) != 0 { errors = append(errors, valueErrors...) return errors } // TODO: check if assignment is valid expr.ValueType = &local.Type expr.Value = assignment case Expression_Literal: literal := expr.Value.(LiteralExpression) switch literal.Literal.Type { case Literal_Boolean, Literal_Number: expr.ValueType = &Type{Type: Type_Primitive, Value: literal.Literal.Primitive} case Literal_String: expr.ValueType = &STRING_TYPE } case Expression_VariableReference: reference := expr.Value.(VariableReferenceExpression) local := getLocal(v.currentBlock, reference.Variable) if local == nil { errors = append(errors, v.createError("Reference to undeclared variable "+reference.Variable, expr.Position)) return errors } expr.ValueType = &local.Type case Expression_Binary: binary := expr.Value.(BinaryExpression) errors = append(errors, v.validateExpression(&binary.Left)...) errors = append(errors, v.validateExpression(&binary.Right)...) if len(errors) != 0 { return errors } if isBooleanOperation(binary.Operation) { if binary.Left.ValueType.Type != Type_Primitive || binary.Right.ValueType.Type != Type_Primitive { errors = append(errors, v.createError("cannot compare non-primitive types", expr.Position)) return errors } leftType := binary.Left.ValueType.Value.(PrimitiveType) rightType := binary.Right.ValueType.Value.(PrimitiveType) var result PrimitiveType = InvalidValue if isPrimitiveTypeExpandableTo(leftType, rightType) { result = leftType } if isPrimitiveTypeExpandableTo(rightType, leftType) { result = leftType } if result == InvalidValue { errors = append(errors, v.createError("cannot compare these types without explicit cast", expr.Position)) return errors } binary.ResultType = &Type{Type: Type_Primitive, Value: result} expr.ValueType = &Type{Type: Type_Primitive, Value: Primitive_Bool} } if isArithmeticOperation(binary.Operation) { if binary.Left.ValueType.Type != Type_Primitive || binary.Right.ValueType.Type != Type_Primitive { errors = append(errors, v.createError("both sides of an arithmetic expression must be a primitive type", expr.Position)) return errors } leftType := binary.Left.ValueType.Value.(PrimitiveType) rightType := binary.Right.ValueType.Value.(PrimitiveType) result, err := v.getArithmeticResultType(expr, leftType, rightType, binary.Operation) if err != nil { errors = append(errors, err) return errors } binary.ResultType = &Type{Type: Type_Primitive, Value: result} expr.ValueType = &Type{Type: Type_Primitive, Value: result} } expr.Value = binary case Expression_Tuple: tuple := expr.Value.(TupleExpression) var types []Type for i := range tuple.Members { member := &tuple.Members[i] memberErrors := v.validateExpression(member) if len(memberErrors) != 0 { errors = append(errors, memberErrors...) continue } types = append(types, *member.ValueType) } if len(errors) != 0 { return errors } expr.ValueType = &Type{Type: Type_Tuple, Value: TupleType{Types: types}} expr.Value = tuple case Expression_FunctionCall: fc := expr.Value.(FunctionCallExpression) var calledFunc *ParsedFunction = nil for _, f := range v.file.Functions { if f.Name == fc.Function { calledFunc = &f break } } if calledFunc == nil { errors = append(errors, v.createError("call to undefined function '"+fc.Function+"'", expr.Position)) return errors } if fc.Parameters != nil { paramsErrors := v.validateExpression(fc.Parameters) if len(paramsErrors) != 0 { errors = append(errors, paramsErrors...) return errors } params := []Expression{*fc.Parameters} if fc.Parameters.Type == Expression_Tuple { params = fc.Parameters.Value.(TupleExpression).Members } if len(params) != len(calledFunc.Parameters) { errors = append(errors, v.createError("wrong number of arguments in function call: expected "+strconv.Itoa(len(calledFunc.Parameters))+", got "+strconv.Itoa(len(params)), expr.Position)) } for i := 0; i < min(len(params), len(calledFunc.Parameters)); i++ { typeGiven := params[i] typeExpected := calledFunc.Parameters[i] if !isTypeExpandableTo(*typeGiven.ValueType, typeExpected.Type) { errors = append(errors, v.createError("invalid type for parameter "+strconv.Itoa(i), expr.Position)) } } } // TODO: get function and validate using return type expr.ValueType = calledFunc.ReturnType expr.Value = fc case Expression_Negate: neg := expr.Value.(NegateExpression) valErrors := v.validateExpression(&neg.Value) if len(valErrors) != 0 { errors = append(errors, valErrors...) return errors } if neg.Value.ValueType.Type != Type_Primitive { errors = append(errors, v.createError("cannot negate non-number types", expr.Position)) } expr.ValueType = neg.Value.ValueType expr.Value = neg default: panic("expr not implemented") } return errors } func (v *Validator) validateExpression(expr *Expression) []error { errors := v.validatePotentiallyVoidExpression(expr) if len(errors) != 0 { return errors } if expr.ValueType == nil { errors = append(errors, v.createError("expression must not evaluate to void", expr.Position)) } return errors } func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local) []error { var errors []error // TODO: support references to variables in parent block switch stmt.Type { case Statement_Expression: expression := stmt.Value.(ExpressionStatement) errors = append(errors, v.validatePotentiallyVoidExpression(&expression.Expression)...) stmt.Value = expression case Statement_Block: block := stmt.Value.(BlockStatement) errors = append(errors, v.validateBlock(block.Block, functionLocals)...) stmt.Value = block case Statement_Return: ret := stmt.Value.(ReturnStatement) if ret.Value != nil { if v.currentFunction.ReturnType == nil { errors = append(errors, v.createError("cannot return value from void function", stmt.Position)) return errors } errors = append(errors, v.validateExpression(ret.Value)...) if len(errors) != 0 { return errors } if !isTypeExpandableTo(*ret.Value.ValueType, *v.currentFunction.ReturnType) { errors = append(errors, v.createError("expression type does not match function return type", ret.Value.Position)) } } else if v.currentFunction.ReturnType != nil { errors = append(errors, v.createError("missing return value", stmt.Position)) } stmt.Value = ret case Statement_DeclareLocalVariable: dlv := stmt.Value.(DeclareLocalVariableStatement) if dlv.Initializer != nil { errors = append(errors, v.validateExpression(dlv.Initializer)...) } if _, ok := v.currentBlock.Locals[dlv.Variable]; ok { errors = append(errors, v.createError("redeclaration of variable '"+dlv.Variable+"'", stmt.Position)) } local := Local{Name: dlv.Variable, Type: dlv.VariableType, IsParameter: false, Index: len(*functionLocals)} v.currentBlock.Locals[dlv.Variable] = local *functionLocals = append(*functionLocals, local) // TODO: check if assignment of initializer is correct stmt.Value = dlv case Statement_If: ifS := stmt.Value.(IfStatement) errors = append(errors, v.validateExpression(&ifS.Condition)...) errors = append(errors, v.validateBlock(ifS.ConditionalBlock, functionLocals)...) if ifS.ElseBlock != nil { errors = append(errors, v.validateBlock(ifS.ElseBlock, functionLocals)...) } if len(errors) != 0 { return errors } if ifS.Condition.ValueType.Type != Type_Primitive || ifS.Condition.ValueType.Value.(PrimitiveType) != Primitive_Bool { errors = append(errors, v.createError("condition must evaluate to boolean", ifS.Condition.Position)) } stmt.Value = ifS default: panic("stmt not implemented") } return errors } func (v *Validator) validateBlock(block *Block, functionLocals *[]Local) []error { var errors []error if block.Locals == nil { block.Locals = make(map[string]Local) } for i := range block.Statements { v.currentBlock = block stmt := &block.Statements[i] errors = append(errors, v.validateStatement(stmt, functionLocals)...) } return errors } func (v *Validator) validateFunction(function *ParsedFunction) []error { var errors []error var locals []Local v.currentFunction = function body := function.Body body.Locals = make(map[string]Local) for _, param := range function.Parameters { local := Local{Name: param.Name, Type: param.Type, IsParameter: true, Index: len(locals)} locals = append(locals, local) body.Locals[param.Name] = local } errors = append(errors, v.validateBlock(body, &locals)...) function.Locals = locals return errors } func (v *Validator) validate() []error { var errors []error for i := range v.file.Imports { errors = append(errors, v.validateImport(&v.file.Imports[i])...) } for i := range v.file.Functions { errors = append(errors, v.validateFunction(&v.file.Functions[i])...) } return errors }