562 lines
16 KiB
Go
562 lines
16 KiB
Go
package main
|
|
|
|
import (
|
|
"fmt"
|
|
"strconv"
|
|
)
|
|
|
|
type Validator struct {
|
|
Files []*ParsedFile
|
|
Wasm64 bool
|
|
AllFunctions map[string]*ParsedFunction
|
|
|
|
CurrentBlock *Block
|
|
CurrentFunction *ParsedFunction
|
|
}
|
|
|
|
func isSameType(a Type, b Type) bool {
|
|
if a.Type != b.Type {
|
|
return false
|
|
}
|
|
|
|
switch a.Type {
|
|
case Type_Primitive:
|
|
return a.Value.(PrimitiveType) == b.Value.(PrimitiveType)
|
|
case Type_Named:
|
|
return a.Value.(NamedType).TypeName == b.Value.(NamedType).TypeName
|
|
case Type_Array:
|
|
return isSameType(a.Value.(ArrayType).ElementType, b.Value.(ArrayType).ElementType)
|
|
case Type_Tuple:
|
|
aTuple := a.Value.(TupleType)
|
|
bTuple := b.Value.(TupleType)
|
|
|
|
if len(aTuple.Types) != len(bTuple.Types) {
|
|
return false
|
|
}
|
|
|
|
for i := 0; i < len(aTuple.Types); i++ {
|
|
if !isSameType(aTuple.Types[i], bTuple.Types[i]) {
|
|
return false
|
|
}
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
panic("type not implemented")
|
|
}
|
|
|
|
func isTypeExpandableTo(from Type, to Type) bool {
|
|
if from.Type != to.Type {
|
|
// cannot convert between primitive, named, array and tuple types
|
|
return false
|
|
}
|
|
|
|
if from.Type == Type_Primitive {
|
|
return isPrimitiveTypeExpandableTo(from.Value.(PrimitiveType), to.Value.(PrimitiveType))
|
|
}
|
|
|
|
if from.Type == Type_Tuple {
|
|
fromT := from.Value.(TupleType)
|
|
toT := to.Value.(TupleType)
|
|
|
|
if len(fromT.Types) != len(toT.Types) {
|
|
return false
|
|
}
|
|
|
|
for i := 0; i < len(fromT.Types); i++ {
|
|
if !isTypeExpandableTo(fromT.Types[i], toT.Types[i]) {
|
|
return false
|
|
}
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
if from.Type == Type_Array {
|
|
return isSameType(from.Value.(ArrayType).ElementType, to.Value.(ArrayType).ElementType)
|
|
}
|
|
|
|
panic("not implemented")
|
|
}
|
|
|
|
func expandExpressionToType(expr *Expression, to Type) {
|
|
// TODO: merge with isTypeExpandableTo?
|
|
|
|
if isSameType(*expr.ValueType, to) {
|
|
return
|
|
}
|
|
|
|
if expr.Type == Expression_Tuple {
|
|
tupleExpr := expr.Value.(TupleExpression)
|
|
tupleType := to.Value.(TupleType)
|
|
|
|
for i := 0; i < len(tupleType.Types); i++ {
|
|
expandExpressionToType(&tupleExpr.Members[i], tupleType.Types[i])
|
|
}
|
|
|
|
expr.Value = tupleExpr
|
|
return
|
|
}
|
|
|
|
*expr = Expression{Type: Expression_Cast, Value: CastExpression{Type: to, Value: *expr}, ValueType: &to, Position: expr.Position}
|
|
}
|
|
|
|
func isPrimitiveTypeExpandableTo(from PrimitiveType, to PrimitiveType) bool {
|
|
if from == to {
|
|
return true
|
|
}
|
|
|
|
switch from {
|
|
case Primitive_I8, Primitive_U8:
|
|
if to == Primitive_I16 || to == Primitive_U16 {
|
|
return true
|
|
}
|
|
|
|
fallthrough
|
|
case Primitive_I16, Primitive_U16:
|
|
if to == Primitive_I32 || to == Primitive_U32 {
|
|
return true
|
|
}
|
|
|
|
fallthrough
|
|
case Primitive_I32, Primitive_U32:
|
|
if to == Primitive_I64 || to == Primitive_U64 {
|
|
return true
|
|
}
|
|
|
|
case Primitive_F32:
|
|
if to == Primitive_F64 {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func (v *Validator) createError(message string, position TokenPosition) error {
|
|
return CompilerError{Position: position, Message: message}
|
|
}
|
|
|
|
func (v *Validator) validateImport(imp *Import) []error {
|
|
// TODO imports
|
|
return nil
|
|
}
|
|
|
|
func (v *Validator) getArithmeticResultType(expr *Expression, left PrimitiveType, right PrimitiveType, operation Operation) (PrimitiveType, error) {
|
|
if left == Primitive_Bool || right == Primitive_Bool {
|
|
return InvalidValue, v.createError("bool type cannot be used in arithmetic expressions", expr.Position)
|
|
}
|
|
|
|
if isPrimitiveTypeExpandableTo(left, right) {
|
|
return right, nil
|
|
}
|
|
|
|
if isPrimitiveTypeExpandableTo(right, left) {
|
|
return left, nil
|
|
}
|
|
|
|
return InvalidValue, v.createError(fmt.Sprintf("cannot use the types [%s, %s] in an arithmetic expression without an explicit cast", left, right), expr.Position) // TODO: include type names in error
|
|
}
|
|
|
|
func getLocal(block *Block, variable string) *Local {
|
|
if local, ok := block.Locals[variable]; ok {
|
|
return &local
|
|
}
|
|
|
|
if block.Parent == nil {
|
|
return nil
|
|
}
|
|
|
|
return getLocal(block.Parent, variable)
|
|
}
|
|
|
|
func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error {
|
|
var errors []error
|
|
|
|
switch expr.Type {
|
|
case Expression_Assignment:
|
|
assignment := expr.Value.(AssignmentExpression)
|
|
|
|
errors = append(errors, v.validateExpression(&assignment.Lhs)...)
|
|
|
|
valueErrors := v.validateExpression(&assignment.Value)
|
|
if len(valueErrors) != 0 {
|
|
errors = append(errors, valueErrors...)
|
|
return errors
|
|
}
|
|
|
|
if !isSameType(*assignment.Value.ValueType, *assignment.Lhs.ValueType) {
|
|
if !isTypeExpandableTo(*assignment.Value.ValueType, *assignment.Lhs.ValueType) {
|
|
errors = append(errors, v.createError(fmt.Sprintf("cannot assign expression of type %s to variable of type %s", *assignment.Value.ValueType, *assignment.Lhs.ValueType), expr.Position))
|
|
}
|
|
|
|
expandExpressionToType(&assignment.Value, *assignment.Lhs.ValueType)
|
|
}
|
|
|
|
expr.ValueType = assignment.Lhs.ValueType
|
|
expr.Value = assignment
|
|
case Expression_Literal:
|
|
literal := expr.Value.(LiteralExpression)
|
|
|
|
switch literal.Literal.Type {
|
|
case Literal_Boolean, Literal_Number:
|
|
expr.ValueType = &Type{Type: Type_Primitive, Value: literal.Literal.Primitive}
|
|
case Literal_String:
|
|
expr.ValueType = &STRING_TYPE
|
|
}
|
|
case Expression_VariableReference:
|
|
reference := expr.Value.(VariableReferenceExpression)
|
|
local := getLocal(v.CurrentBlock, reference.Variable)
|
|
if local == nil {
|
|
errors = append(errors, v.createError("Reference to undeclared variable "+reference.Variable, expr.Position))
|
|
return errors
|
|
}
|
|
|
|
expr.ValueType = &local.Type
|
|
case Expression_Binary:
|
|
binary := expr.Value.(BinaryExpression)
|
|
|
|
errors = append(errors, v.validateExpression(&binary.Left)...)
|
|
errors = append(errors, v.validateExpression(&binary.Right)...)
|
|
|
|
if len(errors) != 0 {
|
|
return errors
|
|
}
|
|
|
|
if isBooleanOperation(binary.Operation) {
|
|
if binary.Left.ValueType.Type != Type_Primitive || binary.Right.ValueType.Type != Type_Primitive {
|
|
errors = append(errors, v.createError("cannot compare non-primitive types", expr.Position))
|
|
return errors
|
|
}
|
|
|
|
var operandType Type
|
|
if isTypeExpandableTo(*binary.Left.ValueType, *binary.Right.ValueType) {
|
|
operandType = *binary.Right.ValueType
|
|
} else if isTypeExpandableTo(*binary.Right.ValueType, *binary.Left.ValueType) {
|
|
operandType = *binary.Left.ValueType
|
|
} else {
|
|
errors = append(errors, v.createError(fmt.Sprintf("cannot compare the types %s and %s without an explicit cast", binary.Left.ValueType.Value.(PrimitiveType), binary.Right.ValueType.Value.(PrimitiveType)), expr.Position))
|
|
return errors
|
|
}
|
|
|
|
expandExpressionToType(&binary.Left, operandType)
|
|
expandExpressionToType(&binary.Right, operandType)
|
|
|
|
expr.ValueType = &Type{Type: Type_Primitive, Value: Primitive_Bool}
|
|
}
|
|
|
|
if isArithmeticOperation(binary.Operation) {
|
|
if binary.Left.ValueType.Type != Type_Primitive || binary.Right.ValueType.Type != Type_Primitive {
|
|
errors = append(errors, v.createError("both sides of an arithmetic expression must evaluate to a primitive type", expr.Position))
|
|
return errors
|
|
}
|
|
|
|
leftType := binary.Left.ValueType.Value.(PrimitiveType)
|
|
rightType := binary.Right.ValueType.Value.(PrimitiveType)
|
|
result, err := v.getArithmeticResultType(expr, leftType, rightType, binary.Operation)
|
|
if err != nil {
|
|
errors = append(errors, err)
|
|
return errors
|
|
}
|
|
|
|
expr.ValueType = &Type{Type: Type_Primitive, Value: result}
|
|
}
|
|
|
|
expr.Value = binary
|
|
case Expression_Tuple:
|
|
tuple := expr.Value.(TupleExpression)
|
|
|
|
var types []Type
|
|
for i := range tuple.Members {
|
|
member := &tuple.Members[i]
|
|
|
|
memberErrors := v.validateExpression(member)
|
|
if len(memberErrors) != 0 {
|
|
errors = append(errors, memberErrors...)
|
|
continue
|
|
}
|
|
|
|
types = append(types, *member.ValueType)
|
|
}
|
|
|
|
if len(errors) != 0 {
|
|
return errors
|
|
}
|
|
|
|
expr.ValueType = &Type{Type: Type_Tuple, Value: TupleType{Types: types}}
|
|
expr.Value = tuple
|
|
case Expression_FunctionCall:
|
|
fc := expr.Value.(FunctionCallExpression)
|
|
|
|
calledFunc, ok := v.AllFunctions[fc.Function]
|
|
if !ok {
|
|
errors = append(errors, v.createError("call to undefined function '"+fc.Function+"'", expr.Position))
|
|
return errors
|
|
}
|
|
|
|
if fc.Parameters != nil {
|
|
paramsErrors := v.validateExpression(fc.Parameters)
|
|
if len(paramsErrors) != 0 {
|
|
errors = append(errors, paramsErrors...)
|
|
return errors
|
|
}
|
|
|
|
params := []*Expression{fc.Parameters}
|
|
if fc.Parameters.Type == Expression_Tuple {
|
|
for i := 0; i < len(fc.Parameters.Value.(TupleExpression).Members); i++ {
|
|
params[i] = &fc.Parameters.Value.(TupleExpression).Members[i]
|
|
}
|
|
}
|
|
|
|
if len(params) != len(calledFunc.Parameters) {
|
|
errors = append(errors, v.createError("wrong number of arguments in function call: expected "+strconv.Itoa(len(calledFunc.Parameters))+", got "+strconv.Itoa(len(params)), expr.Position))
|
|
}
|
|
|
|
for i := 0; i < min(len(params), len(calledFunc.Parameters)); i++ {
|
|
typeGiven := params[i]
|
|
typeExpected := calledFunc.Parameters[i]
|
|
if !isTypeExpandableTo(*typeGiven.ValueType, typeExpected.Type) {
|
|
errors = append(errors, v.createError("invalid type for parameter "+strconv.Itoa(i), expr.Position))
|
|
}
|
|
|
|
expandExpressionToType(typeGiven, typeExpected.Type)
|
|
}
|
|
}
|
|
|
|
// TODO: get function and validate using return type
|
|
expr.ValueType = calledFunc.ReturnType
|
|
expr.Value = fc
|
|
case Expression_Negate:
|
|
neg := expr.Value.(NegateExpression)
|
|
|
|
valErrors := v.validateExpression(&neg.Value)
|
|
if len(valErrors) != 0 {
|
|
errors = append(errors, valErrors...)
|
|
return errors
|
|
}
|
|
|
|
if neg.Value.ValueType.Type != Type_Primitive {
|
|
errors = append(errors, v.createError("cannot negate non-number types", expr.Position))
|
|
}
|
|
|
|
expr.ValueType = neg.Value.ValueType
|
|
expr.Value = neg
|
|
case Expression_RawMemoryReference:
|
|
raw := expr.Value.(RawMemoryReferenceExpression)
|
|
|
|
addrErrors := v.validateExpression(&raw.Address)
|
|
if len(addrErrors) != 0 {
|
|
errors = append(errors, addrErrors...)
|
|
return errors
|
|
}
|
|
|
|
if raw.Address.ValueType.Type != Type_Primitive || raw.Address.ValueType.Value.(PrimitiveType) != Primitive_U64 {
|
|
errors = append(errors, v.createError("address must evaluate to a u64 value", expr.Position))
|
|
return errors
|
|
}
|
|
|
|
if !v.Wasm64 {
|
|
castTo := Type{Type: Type_Primitive, Value: Primitive_U32}
|
|
raw.Address = Expression{Type: Expression_Cast, Value: CastExpression{Type: castTo, Value: raw.Address}, ValueType: &castTo, Position: raw.Address.Position}
|
|
}
|
|
|
|
expr.ValueType = &raw.Type
|
|
expr.Value = raw
|
|
default:
|
|
panic("expr not implemented")
|
|
}
|
|
|
|
return errors
|
|
}
|
|
|
|
func (v *Validator) validateExpression(expr *Expression) []error {
|
|
errors := v.validatePotentiallyVoidExpression(expr)
|
|
if len(errors) != 0 {
|
|
return errors
|
|
}
|
|
|
|
if expr.ValueType == nil {
|
|
errors = append(errors, v.createError("expression must not evaluate to void", expr.Position))
|
|
}
|
|
|
|
return errors
|
|
}
|
|
|
|
func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local) []error {
|
|
var errors []error
|
|
|
|
switch stmt.Type {
|
|
case Statement_Expression:
|
|
expression := stmt.Value.(ExpressionStatement)
|
|
errors = append(errors, v.validatePotentiallyVoidExpression(&expression.Expression)...)
|
|
stmt.Value = expression
|
|
case Statement_Block:
|
|
block := stmt.Value.(BlockStatement)
|
|
errors = append(errors, v.validateBlock(block.Block, functionLocals)...)
|
|
stmt.Value = block
|
|
case Statement_Return:
|
|
ret := stmt.Value.(ReturnStatement)
|
|
if ret.Value != nil {
|
|
if v.CurrentFunction.ReturnType == nil {
|
|
errors = append(errors, v.createError("cannot return value from void function", stmt.Position))
|
|
return errors
|
|
}
|
|
|
|
errors = append(errors, v.validateExpression(ret.Value)...)
|
|
|
|
if len(errors) != 0 {
|
|
return errors
|
|
}
|
|
|
|
if !isTypeExpandableTo(*ret.Value.ValueType, *v.CurrentFunction.ReturnType) {
|
|
errors = append(errors, v.createError(fmt.Sprintf("cannot return value of type %s from function returning %s", *ret.Value.ValueType, *v.CurrentFunction.ReturnType), ret.Value.Position))
|
|
}
|
|
|
|
expandExpressionToType(ret.Value, *v.CurrentFunction.ReturnType)
|
|
} else if v.CurrentFunction.ReturnType != nil {
|
|
errors = append(errors, v.createError("missing return value", stmt.Position))
|
|
}
|
|
|
|
stmt.Value = ret
|
|
case Statement_DeclareLocalVariable:
|
|
dlv := stmt.Value.(DeclareLocalVariableStatement)
|
|
if dlv.Initializer != nil {
|
|
errors = append(errors, v.validateExpression(dlv.Initializer)...)
|
|
if errors != nil {
|
|
return errors
|
|
}
|
|
|
|
if !isTypeExpandableTo(*dlv.Initializer.ValueType, dlv.VariableType) {
|
|
errors = append(errors, v.createError(fmt.Sprintf("cannot assign expression of type %s to variable of type %s", *dlv.Initializer.ValueType, dlv.VariableType), stmt.Position))
|
|
}
|
|
|
|
expandExpressionToType(dlv.Initializer, dlv.VariableType)
|
|
}
|
|
|
|
if getLocal(v.CurrentBlock, dlv.Variable) != nil {
|
|
errors = append(errors, v.createError("redeclaration of variable '"+dlv.Variable+"'", stmt.Position))
|
|
}
|
|
|
|
local := Local{Name: dlv.Variable, Type: dlv.VariableType, IsParameter: false, Index: len(*functionLocals)}
|
|
v.CurrentBlock.Locals[dlv.Variable] = local
|
|
*functionLocals = append(*functionLocals, local)
|
|
|
|
stmt.Value = dlv
|
|
case Statement_If:
|
|
ifS := stmt.Value.(IfStatement)
|
|
|
|
errors = append(errors, v.validateExpression(&ifS.Condition)...)
|
|
errors = append(errors, v.validateBlock(ifS.ConditionalBlock, functionLocals)...)
|
|
|
|
if ifS.ElseBlock != nil {
|
|
errors = append(errors, v.validateBlock(ifS.ElseBlock, functionLocals)...)
|
|
}
|
|
|
|
if len(errors) != 0 {
|
|
return errors
|
|
}
|
|
|
|
if ifS.Condition.ValueType.Type != Type_Primitive || ifS.Condition.ValueType.Value.(PrimitiveType) != Primitive_Bool {
|
|
errors = append(errors, v.createError("condition must evaluate to boolean", ifS.Condition.Position))
|
|
}
|
|
|
|
stmt.Value = ifS
|
|
case Statement_WhileLoop:
|
|
while := stmt.Value.(WhileLoopStatement)
|
|
|
|
errors = append(errors, v.validateExpression(&while.Condition)...)
|
|
errors = append(errors, v.validateBlock(while.Body, functionLocals)...)
|
|
|
|
if len(errors) != 0 {
|
|
return errors
|
|
}
|
|
|
|
if while.Condition.ValueType.Type != Type_Primitive || while.Condition.ValueType.Value.(PrimitiveType) != Primitive_Bool {
|
|
errors = append(errors, v.createError("condition must evaluate to boolean", while.Condition.Position))
|
|
}
|
|
|
|
stmt.Value = while
|
|
default:
|
|
panic("stmt not implemented")
|
|
}
|
|
|
|
return errors
|
|
}
|
|
|
|
func (v *Validator) validateBlock(block *Block, functionLocals *[]Local) []error {
|
|
var errors []error
|
|
|
|
if block.Locals == nil {
|
|
block.Locals = make(map[string]Local)
|
|
}
|
|
|
|
for i := range block.Statements {
|
|
v.CurrentBlock = block
|
|
stmt := &block.Statements[i]
|
|
errors = append(errors, v.validateStatement(stmt, functionLocals)...)
|
|
}
|
|
|
|
return errors
|
|
}
|
|
|
|
func (v *Validator) validateFunction(function *ParsedFunction) []error {
|
|
var errors []error
|
|
|
|
var locals []Local
|
|
|
|
v.CurrentFunction = function
|
|
|
|
body := function.Body
|
|
body.Locals = make(map[string]Local)
|
|
for _, param := range function.Parameters {
|
|
local := Local{Name: param.Name, Type: param.Type, IsParameter: true, Index: len(locals)}
|
|
locals = append(locals, local)
|
|
body.Locals[param.Name] = local
|
|
}
|
|
|
|
errors = append(errors, v.validateBlock(body, &locals)...)
|
|
|
|
// TODO: validate that function returns return value
|
|
|
|
function.Locals = locals
|
|
|
|
return errors
|
|
}
|
|
|
|
func (v *Validator) validate() []error {
|
|
var errors []error
|
|
|
|
v.AllFunctions = make(map[string]*ParsedFunction)
|
|
for _, file := range v.Files {
|
|
for i := range file.Functions {
|
|
function := &file.Functions[i]
|
|
|
|
fullFunctionName := function.Name
|
|
if file.Module != "" {
|
|
fullFunctionName = file.Module + "." + fullFunctionName
|
|
}
|
|
|
|
function.FullName = fullFunctionName
|
|
|
|
if _, exists := v.AllFunctions[fullFunctionName]; exists {
|
|
errors = append(errors, v.createError("duplicate function "+fullFunctionName, function.ReturnType.Position))
|
|
}
|
|
|
|
v.AllFunctions[fullFunctionName] = function
|
|
}
|
|
}
|
|
|
|
for _, file := range v.Files {
|
|
for i := range file.Imports {
|
|
errors = append(errors, v.validateImport(&file.Imports[i])...)
|
|
}
|
|
|
|
for i := range file.Functions {
|
|
errors = append(errors, v.validateFunction(&file.Functions[i])...)
|
|
}
|
|
}
|
|
|
|
return errors
|
|
}
|