package main import ( "fmt" "strconv" ) type Validator struct { Files []*ParsedFile Wasm64 bool AllFunctions map[string]*ParsedFunction CurrentBlock *Block CurrentFunction *ParsedFunction } const ( BUILTIN_MEMORY_GROW = "__builtin_memory_grow" BUILTIN_MEMORY_SIZE = "__builtin_memory_size" ) var builtinFunctions map[string]*ParsedFunction = map[string]*ParsedFunction{ BUILTIN_MEMORY_GROW: { Name: BUILTIN_MEMORY_GROW, FullName: "builtin." + BUILTIN_MEMORY_GROW, Parameters: []ParsedParameter{{Name: "memory", Type: Type{Type: Type_Primitive, Value: Primitive_U64, Position: unknownPosition()}}}, ReturnType: &Type{Type: Type_Primitive, Value: Primitive_I64, Position: unknownPosition()}, }, BUILTIN_MEMORY_SIZE: { Name: BUILTIN_MEMORY_SIZE, FullName: "builtin." + BUILTIN_MEMORY_SIZE, Parameters: []ParsedParameter{}, ReturnType: &Type{Type: Type_Primitive, Value: Primitive_U64, Position: unknownPosition()}, }, } func isSameType(a Type, b Type) bool { if a.Type != b.Type { return false } switch a.Type { case Type_Primitive: return a.Value.(PrimitiveType) == b.Value.(PrimitiveType) case Type_Named: return a.Value.(NamedType).TypeName == b.Value.(NamedType).TypeName case Type_Array: return isSameType(a.Value.(ArrayType).ElementType, b.Value.(ArrayType).ElementType) case Type_Tuple: aTuple := a.Value.(TupleType) bTuple := b.Value.(TupleType) if len(aTuple.Types) != len(bTuple.Types) { return false } for i := 0; i < len(aTuple.Types); i++ { if !isSameType(aTuple.Types[i], bTuple.Types[i]) { return false } } return true } panic("type not implemented") } func isTypeExpandableTo(from Type, to Type) bool { if from.Type != to.Type { // cannot convert between primitive, named, array and tuple types return false } if from.Type == Type_Primitive { return isPrimitiveTypeExpandableTo(from.Value.(PrimitiveType), to.Value.(PrimitiveType)) } if from.Type == Type_Tuple { fromT := from.Value.(TupleType) toT := to.Value.(TupleType) if len(fromT.Types) != len(toT.Types) { return false } for i := 0; i < len(fromT.Types); i++ { if !isTypeExpandableTo(fromT.Types[i], toT.Types[i]) { return false } } return true } if from.Type == Type_Array { return isSameType(from.Value.(ArrayType).ElementType, to.Value.(ArrayType).ElementType) } panic("not implemented") } func expandExpressionToType(expr *Expression, to Type) { // TODO: merge with isTypeExpandableTo? if isSameType(*expr.ValueType, to) { return } if expr.Type == Expression_Tuple { tupleExpr := expr.Value.(TupleExpression) tupleType := to.Value.(TupleType) for i := 0; i < len(tupleType.Types); i++ { expandExpressionToType(&tupleExpr.Members[i], tupleType.Types[i]) } expr.Value = tupleExpr return } *expr = Expression{Type: Expression_Cast, Value: CastExpression{Type: to, Value: *expr}, ValueType: &to, Position: expr.Position} } func isPrimitiveTypeExpandableTo(from PrimitiveType, to PrimitiveType) bool { if from == to { return true } if isSignedInt(from) && !isSignedInt(to) { return false } switch from { case Primitive_I8, Primitive_U8: if to == Primitive_I16 || to == Primitive_U16 { return true } fallthrough case Primitive_I16, Primitive_U16: if to == Primitive_I32 || to == Primitive_U32 { return true } fallthrough case Primitive_I32, Primitive_U32: if to == Primitive_I64 || to == Primitive_U64 { return true } case Primitive_F32: if to == Primitive_F64 { return true } } return false } func (v *Validator) createError(message string, position TokenPosition) error { return CompilerError{Position: position, Message: message} } func (v *Validator) validateImport(imp *Import) []error { // TODO imports return nil } func (v *Validator) getArithmeticResultType(expr *Expression, left PrimitiveType, right PrimitiveType, operation Operation) (PrimitiveType, error) { if left == Primitive_Bool || right == Primitive_Bool { return InvalidValue, v.createError("bool type cannot be used in arithmetic expressions", expr.Position) } if isPrimitiveTypeExpandableTo(left, right) { return right, nil } if isPrimitiveTypeExpandableTo(right, left) { return left, nil } return InvalidValue, v.createError(fmt.Sprintf("cannot use the types [%s, %s] in an arithmetic expression without an explicit cast", left, right), expr.Position) } func getLocal(block *Block, variable string) *Local { if local, ok := block.Locals[variable]; ok { return &local } if block.Parent == nil { return nil } return getLocal(block.Parent, variable) } func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error { var errors []error switch expr.Type { case Expression_Assignment: assignment := expr.Value.(AssignmentExpression) errors = append(errors, v.validateExpression(&assignment.Lhs)...) valueErrors := v.validateExpression(&assignment.Value) if len(valueErrors) != 0 { errors = append(errors, valueErrors...) return errors } if assignment.Operation != Operation_Equals && (assignment.Lhs.ValueType.Type != Type_Primitive || assignment.Value.ValueType.Type != Type_Primitive) { errors = append(errors, v.createError("both sides of an arithmetic expression must evaluate to a primitive type", expr.Position)) return errors } if !isSameType(*assignment.Value.ValueType, *assignment.Lhs.ValueType) { if !isTypeExpandableTo(*assignment.Value.ValueType, *assignment.Lhs.ValueType) { errors = append(errors, v.createError(fmt.Sprintf("cannot assign expression of type %s to variable of type %s", *assignment.Value.ValueType, *assignment.Lhs.ValueType), expr.Position)) } expandExpressionToType(&assignment.Value, *assignment.Lhs.ValueType) } expr.ValueType = assignment.Lhs.ValueType expr.Value = assignment case Expression_Literal: literal := expr.Value.(LiteralExpression) switch literal.Literal.Type { case Literal_Boolean, Literal_Number: expr.ValueType = &Type{Type: Type_Primitive, Value: literal.Literal.Primitive} case Literal_String: expr.ValueType = &STRING_TYPE } case Expression_VariableReference: reference := expr.Value.(VariableReferenceExpression) local := getLocal(v.CurrentBlock, reference.Variable) if local == nil { errors = append(errors, v.createError("Reference to undeclared variable "+reference.Variable, expr.Position)) return errors } expr.ValueType = &local.Type case Expression_Binary: binary := expr.Value.(BinaryExpression) errors = append(errors, v.validateExpression(&binary.Left)...) errors = append(errors, v.validateExpression(&binary.Right)...) if len(errors) != 0 { return errors } if isBooleanOperation(binary.Operation) { if binary.Left.ValueType.Type != Type_Primitive || binary.Right.ValueType.Type != Type_Primitive { errors = append(errors, v.createError("cannot compare non-primitive types", expr.Position)) return errors } var operandType Type if isTypeExpandableTo(*binary.Left.ValueType, *binary.Right.ValueType) { operandType = *binary.Right.ValueType } else if isTypeExpandableTo(*binary.Right.ValueType, *binary.Left.ValueType) { operandType = *binary.Left.ValueType } else { errors = append(errors, v.createError(fmt.Sprintf("cannot compare the types %s and %s without an explicit cast", binary.Left.ValueType.Value.(PrimitiveType), binary.Right.ValueType.Value.(PrimitiveType)), expr.Position)) return errors } expandExpressionToType(&binary.Left, operandType) expandExpressionToType(&binary.Right, operandType) expr.ValueType = &Type{Type: Type_Primitive, Value: Primitive_Bool} } if isArithmeticOperation(binary.Operation) { if binary.Left.ValueType.Type != Type_Primitive || binary.Right.ValueType.Type != Type_Primitive { errors = append(errors, v.createError("both sides of an arithmetic expression must evaluate to a primitive type", expr.Position)) return errors } leftType := binary.Left.ValueType.Value.(PrimitiveType) rightType := binary.Right.ValueType.Value.(PrimitiveType) result, err := v.getArithmeticResultType(expr, leftType, rightType, binary.Operation) if err != nil { errors = append(errors, err) return errors } expr.ValueType = &Type{Type: Type_Primitive, Value: result} } expr.Value = binary case Expression_Tuple: tuple := expr.Value.(TupleExpression) var types []Type for i := range tuple.Members { member := &tuple.Members[i] memberErrors := v.validateExpression(member) if len(memberErrors) != 0 { errors = append(errors, memberErrors...) continue } types = append(types, *member.ValueType) } if len(errors) != 0 { return errors } expr.ValueType = &Type{Type: Type_Tuple, Value: TupleType{Types: types}} expr.Value = tuple case Expression_FunctionCall: fc := expr.Value.(FunctionCallExpression) calledFunc, ok := builtinFunctions[fc.Function] if !ok { calledFunc, ok = v.AllFunctions[fc.Function] if !ok { errors = append(errors, v.createError("call to undefined function '"+fc.Function+"'", expr.Position)) return errors } } if fc.Parameters != nil { paramsErrors := v.validateExpression(fc.Parameters) if len(paramsErrors) != 0 { errors = append(errors, paramsErrors...) return errors } } var params []*Expression if fc.Parameters == nil { params = []*Expression{} } else if fc.Parameters.Type == Expression_Tuple { params = make([]*Expression, len(fc.Parameters.Value.(TupleExpression).Members)) for i := 0; i < len(fc.Parameters.Value.(TupleExpression).Members); i++ { params[i] = &fc.Parameters.Value.(TupleExpression).Members[i] } } else { params = []*Expression{fc.Parameters} } if len(params) != len(calledFunc.Parameters) { errors = append(errors, v.createError("wrong number of arguments in function call: expected "+strconv.Itoa(len(calledFunc.Parameters))+", got "+strconv.Itoa(len(params)), expr.Position)) } for i := 0; i < min(len(params), len(calledFunc.Parameters)); i++ { typeGiven := params[i] typeExpected := calledFunc.Parameters[i] if !isTypeExpandableTo(*typeGiven.ValueType, typeExpected.Type) { errors = append(errors, v.createError("invalid type for parameter "+strconv.Itoa(i)+": expected "+typeExpected.Type.String()+", got "+typeGiven.ValueType.String(), expr.Position)) } expandExpressionToType(typeGiven, typeExpected.Type) } expr.ValueType = calledFunc.ReturnType expr.Value = fc case Expression_Unary: unary := expr.Value.(UnaryExpression) valErrors := v.validateExpression(&unary.Value) if len(valErrors) != 0 { errors = append(errors, valErrors...) return errors } if unary.Value.ValueType.Type != Type_Primitive { errors = append(errors, v.createError("cannot operate on non-primitive types", expr.Position)) } else { primitive := unary.Value.ValueType.Value.(PrimitiveType) if (unary.Operation == UnaryOperation_Negate || unary.Operation == UnaryOperation_Nop /* + sign */) && !isSignedInt(primitive) && !isFloatingPoint(primitive) { errors = append(errors, v.createError("can only perform negation/unary plus on signed types", expr.Position)) } if unary.Operation == UnaryOperation_LogicalNot && primitive != Primitive_Bool { errors = append(errors, v.createError("cannot perform logical not on non-bool type", expr.Position)) } if unary.Operation == UnaryOperation_BitwiseNot && !isUnsignedInt(primitive) && !isSignedInt(primitive) { errors = append(errors, v.createError("cannot perform bitwise not on non-integer type", expr.Position)) } } expr.ValueType = unary.Value.ValueType expr.Value = unary case Expression_RawMemoryReference: raw := expr.Value.(RawMemoryReferenceExpression) addrErrors := v.validateExpression(&raw.Address) if len(addrErrors) != 0 { errors = append(errors, addrErrors...) return errors } typeU64 := Type{Type: Type_Primitive, Value: Primitive_U64, Position: unknownPosition()} if !isTypeExpandableTo(*raw.Address.ValueType, typeU64) { errors = append(errors, v.createError("address must be expandable to a u64 value", expr.Position)) return errors } expandExpressionToType(&raw.Address, typeU64) if !v.Wasm64 { castTo := Type{Type: Type_Primitive, Value: Primitive_U32} raw.Address = Expression{Type: Expression_Cast, Value: CastExpression{Type: castTo, Value: raw.Address}, ValueType: &castTo, Position: raw.Address.Position} } expr.ValueType = &raw.Type expr.Value = raw case Expression_ArrayAccess: array := expr.Value.(ArrayAccessExpression) arrayErrors := v.validateExpression(&array.Array) if len(arrayErrors) != 0 { errors = append(errors, arrayErrors...) } indexErrors := v.validateExpression(&array.Index) if len(indexErrors) != 0 { errors = append(errors, indexErrors...) return errors } if len(errors) != 0 { return errors } if array.Array.ValueType.Type != Type_Array { errors = append(errors, v.createError("trying to access non-array type as an array", array.Array.Position)) } typeI64 := Type{Type: Type_Primitive, Value: Primitive_I64, Position: unknownPosition()} if !isTypeExpandableTo(*array.Index.ValueType, typeI64) { errors = append(errors, v.createError("array index must be expandable to an i64 value", array.Index.Position)) return errors } expandExpressionToType(&array.Index, typeI64) if len(errors) != 0 { return errors } elementType := array.Array.ValueType.Value.(ArrayType).ElementType expr.ValueType = &elementType expr.Value = array default: panic("expr not implemented") } return errors } func (v *Validator) validateExpression(expr *Expression) []error { errors := v.validatePotentiallyVoidExpression(expr) if len(errors) != 0 { return errors } if expr.ValueType == nil { errors = append(errors, v.createError("expression must not evaluate to void", expr.Position)) } return errors } func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local) []error { var errors []error switch stmt.Type { case Statement_Expression: expression := stmt.Value.(ExpressionStatement) errors = append(errors, v.validatePotentiallyVoidExpression(&expression.Expression)...) stmt.Value = expression case Statement_Block: block := stmt.Value.(BlockStatement) errors = append(errors, v.validateBlock(block.Block, functionLocals)...) stmt.Returns = block.Block.Returns stmt.Value = block case Statement_Return: ret := stmt.Value.(ReturnStatement) stmt.Returns = true if ret.Value != nil { if v.CurrentFunction.ReturnType == nil { errors = append(errors, v.createError("cannot return value from void function", stmt.Position)) return errors } errors = append(errors, v.validateExpression(ret.Value)...) if len(errors) != 0 { return errors } if !isTypeExpandableTo(*ret.Value.ValueType, *v.CurrentFunction.ReturnType) { errors = append(errors, v.createError(fmt.Sprintf("cannot return value of type %s from function returning %s", *ret.Value.ValueType, *v.CurrentFunction.ReturnType), ret.Value.Position)) } expandExpressionToType(ret.Value, *v.CurrentFunction.ReturnType) } else if v.CurrentFunction.ReturnType != nil { errors = append(errors, v.createError("missing return value", stmt.Position)) } stmt.Value = ret case Statement_DeclareLocalVariable: dlv := stmt.Value.(DeclareLocalVariableStatement) if dlv.Initializer != nil { errors = append(errors, v.validateExpression(dlv.Initializer)...) if errors != nil { return errors } if !isTypeExpandableTo(*dlv.Initializer.ValueType, dlv.VariableType) { errors = append(errors, v.createError(fmt.Sprintf("cannot assign expression of type %s to variable of type %s", *dlv.Initializer.ValueType, dlv.VariableType), stmt.Position)) } expandExpressionToType(dlv.Initializer, dlv.VariableType) } if getLocal(v.CurrentBlock, dlv.Variable) != nil { errors = append(errors, v.createError("redeclaration of variable '"+dlv.Variable+"'", stmt.Position)) } local := Local{Name: dlv.Variable, Type: dlv.VariableType, IsParameter: false, Index: len(*functionLocals)} v.CurrentBlock.Locals[dlv.Variable] = local *functionLocals = append(*functionLocals, local) stmt.Value = dlv case Statement_If: ifS := stmt.Value.(IfStatement) errors = append(errors, v.validateExpression(&ifS.Condition)...) errors = append(errors, v.validateBlock(ifS.ConditionalBlock, functionLocals)...) if ifS.ElseBlock != nil { errors = append(errors, v.validateBlock(ifS.ElseBlock, functionLocals)...) } stmt.Returns = ifS.ConditionalBlock.Returns && ifS.ElseBlock != nil && ifS.ElseBlock.Returns if len(errors) != 0 { return errors } if ifS.Condition.ValueType.Type != Type_Primitive || ifS.Condition.ValueType.Value.(PrimitiveType) != Primitive_Bool { errors = append(errors, v.createError("condition must evaluate to boolean", ifS.Condition.Position)) } stmt.Value = ifS case Statement_WhileLoop: while := stmt.Value.(WhileLoopStatement) errors = append(errors, v.validateExpression(&while.Condition)...) errors = append(errors, v.validateBlock(while.Body, functionLocals)...) if len(errors) != 0 { return errors } if while.Condition.ValueType.Type != Type_Primitive || while.Condition.ValueType.Value.(PrimitiveType) != Primitive_Bool { errors = append(errors, v.createError("condition must evaluate to boolean", while.Condition.Position)) } stmt.Value = while default: panic("stmt not implemented") } return errors } func (v *Validator) validateBlock(block *Block, functionLocals *[]Local) []error { var errors []error if block.Locals == nil { block.Locals = make(map[string]Local) } for i := range block.Statements { v.CurrentBlock = block stmt := &block.Statements[i] errors = append(errors, v.validateStatement(stmt, functionLocals)...) if stmt.Returns { block.Returns = true } } return errors } func (v *Validator) validateFunction(function *ParsedFunction) []error { var errors []error var locals []Local v.CurrentFunction = function body := function.Body body.Locals = make(map[string]Local) for _, param := range function.Parameters { local := Local{Name: param.Name, Type: param.Type, IsParameter: true, Index: len(locals)} locals = append(locals, local) body.Locals[param.Name] = local } errors = append(errors, v.validateBlock(body, &locals)...) if function.ReturnType != nil && !body.Returns { errors = append(errors, v.createError("function must return a value", function.ReturnType.Position)) } function.Locals = locals return errors } func (v *Validator) validate() []error { var errors []error v.AllFunctions = make(map[string]*ParsedFunction) for _, file := range v.Files { for i := range file.Functions { function := &file.Functions[i] fullFunctionName := function.Name if file.Module != "" { fullFunctionName = file.Module + "." + fullFunctionName } function.FullName = fullFunctionName if _, exists := v.AllFunctions[fullFunctionName]; exists { errors = append(errors, v.createError("duplicate function "+fullFunctionName, function.ReturnType.Position)) } v.AllFunctions[fullFunctionName] = function } } for _, file := range v.Files { for i := range file.Imports { errors = append(errors, v.validateImport(&file.Imports[i])...) } for i := range file.Functions { errors = append(errors, v.validateFunction(&file.Functions[i])...) } } return errors }