311 lines
8.4 KiB
Go
311 lines
8.4 KiB
Go
package main
|
|
|
|
import (
	"errors"
	"reflect"
	"strconv"
)
|
|
|
|
type Validator struct {
|
|
file *ParsedFile
|
|
}
|
|
|
|
func isTypeExpandableTo(from Type, to Type) bool {
|
|
if from.Type == Type_Primitive && to.Type == Type_Primitive {
|
|
return isPrimitiveTypeExpandableTo(from.Value.(PrimitiveType), to.Value.(PrimitiveType))
|
|
}
|
|
|
|
panic("not implemented")
|
|
}
|
|
|
|
func isPrimitiveTypeExpandableTo(from PrimitiveType, to PrimitiveType) bool {
|
|
if from == to {
|
|
return true
|
|
}
|
|
|
|
switch from {
|
|
case Primitive_I8, Primitive_U8:
|
|
if to == Primitive_I16 || to == Primitive_U16 {
|
|
return true
|
|
}
|
|
|
|
fallthrough
|
|
case Primitive_I16, Primitive_U16:
|
|
if to == Primitive_I32 || to == Primitive_U32 {
|
|
return true
|
|
}
|
|
|
|
fallthrough
|
|
case Primitive_I32, Primitive_U32:
|
|
if to == Primitive_I64 || to == Primitive_U64 {
|
|
return true
|
|
}
|
|
|
|
case Primitive_F32:
|
|
if to == Primitive_F64 {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func (v *Validator) createError(message string) error {
|
|
// TODO: pass token and get actual token position
|
|
return errors.New(message)
|
|
}
|
|
|
|
func (v *Validator) validateImport(imp *Import) []error {
|
|
// TODO
|
|
return nil
|
|
}
|
|
|
|
func (v *Validator) getArithmeticResultType(left PrimitiveType, right PrimitiveType, operation ArithmeticOperation) (PrimitiveType, error) {
|
|
if left == Primitive_Bool || right == Primitive_Bool {
|
|
return InvalidValue, v.createError("bool type cannot be used in arithmetic expressions")
|
|
}
|
|
|
|
if isPrimitiveTypeExpandableTo(left, right) {
|
|
return right, nil
|
|
}
|
|
|
|
if isPrimitiveTypeExpandableTo(right, left) {
|
|
return left, nil
|
|
}
|
|
|
|
// TODO: boolean expressions etc.
|
|
|
|
return InvalidValue, v.createError("cannot use these types in an arithmetic expression without an explicit cast") // TODO: include type names in error
|
|
}
|
|
|
|
func (v *Validator) validateExpression(expr *Expression, block *Block) []error {
|
|
var errors []error
|
|
|
|
switch expr.Type {
|
|
case Expression_Assignment:
|
|
assignment := expr.Value.(AssignmentExpression)
|
|
var local Local
|
|
var ok bool
|
|
if local, ok = block.Locals[assignment.Variable]; !ok {
|
|
errors = append(errors, v.createError("Assignment to undeclared variable "+assignment.Variable))
|
|
return errors
|
|
}
|
|
|
|
valueErrors := v.validateExpression(&assignment.Value, block)
|
|
if len(valueErrors) != 0 {
|
|
errors = append(errors, valueErrors...)
|
|
return errors
|
|
}
|
|
|
|
// TODO: check if assignment is valid
|
|
expr.ValueType = local.Type
|
|
expr.Value = assignment
|
|
case Expression_Literal:
|
|
literal := expr.Value.(LiteralExpression)
|
|
|
|
switch literal.Literal.Type {
|
|
case Literal_Boolean, Literal_Number:
|
|
expr.ValueType = Type{Type: Type_Primitive, Value: literal.Literal.Primitive}
|
|
case Literal_String:
|
|
expr.ValueType = STRING_TYPE
|
|
}
|
|
case Expression_VariableReference:
|
|
reference := expr.Value.(VariableReferenceExpression)
|
|
var local Local
|
|
var ok bool
|
|
if local, ok = block.Locals[reference.Variable]; !ok {
|
|
errors = append(errors, v.createError("Reference to undeclared variable "+reference.Variable))
|
|
return errors
|
|
}
|
|
|
|
expr.ValueType = local.Type
|
|
case Expression_Arithmetic:
|
|
arithmethic := expr.Value.(ArithmeticExpression)
|
|
|
|
errors = append(errors, v.validateExpression(&arithmethic.Left, block)...)
|
|
errors = append(errors, v.validateExpression(&arithmethic.Right, block)...)
|
|
|
|
if len(errors) != 0 {
|
|
return errors
|
|
}
|
|
|
|
if arithmethic.Left.ValueType.Type != Type_Primitive || arithmethic.Right.ValueType.Type != Type_Primitive {
|
|
errors = append(errors, v.createError("both sides of an arithmetic expression must be a primitive type"))
|
|
return errors
|
|
}
|
|
|
|
leftType := arithmethic.Left.ValueType.Value.(PrimitiveType)
|
|
rightType := arithmethic.Right.ValueType.Value.(PrimitiveType)
|
|
result, err := v.getArithmeticResultType(leftType, rightType, arithmethic.Operation)
|
|
if err != nil {
|
|
errors = append(errors, err)
|
|
return errors
|
|
}
|
|
|
|
expr.ValueType = Type{Type: Type_Primitive, Value: result}
|
|
expr.Value = arithmethic
|
|
case Expression_Tuple:
|
|
tuple := expr.Value.(TupleExpression)
|
|
|
|
var types []Type
|
|
for i := range tuple.Members {
|
|
member := &tuple.Members[i]
|
|
|
|
memberErrors := v.validateExpression(member, block)
|
|
if len(memberErrors) != 0 {
|
|
errors = append(errors, memberErrors...)
|
|
continue
|
|
}
|
|
|
|
types = append(types, member.ValueType)
|
|
}
|
|
|
|
if len(errors) != 0 {
|
|
return errors
|
|
}
|
|
|
|
expr.ValueType = Type{Type: Type_Tuple, Value: TupleType{Types: types}}
|
|
expr.Value = tuple
|
|
case Expression_FunctionCall:
|
|
fc := expr.Value.(FunctionCallExpression)
|
|
|
|
var calledFunc *ParsedFunction = nil
|
|
for _, f := range v.file.Functions {
|
|
if f.Name == fc.Function {
|
|
calledFunc = &f
|
|
break
|
|
}
|
|
}
|
|
|
|
if calledFunc == nil {
|
|
errors = append(errors, v.createError("call to undefined function '"+fc.Function+"'"))
|
|
return errors
|
|
}
|
|
|
|
if fc.Parameters != nil {
|
|
errors = append(errors, v.validateExpression(fc.Parameters, block)...)
|
|
|
|
params := []Expression{*fc.Parameters}
|
|
if fc.Parameters.Type == Expression_Tuple {
|
|
params = fc.Parameters.Value.(TupleExpression).Members
|
|
}
|
|
|
|
if len(params) != len(calledFunc.Parameters) {
|
|
errors = append(errors, v.createError("wrong number of arguments in function call: expected "+strconv.Itoa(len(calledFunc.Parameters))+", got "+strconv.Itoa(len(params))))
|
|
}
|
|
|
|
for i := 0; i < min(len(params), len(calledFunc.Parameters)); i++ {
|
|
typeGiven := params[i]
|
|
typeExpected := calledFunc.Parameters[i]
|
|
if !isTypeExpandableTo(typeGiven.ValueType, typeExpected.Type) {
|
|
errors = append(errors, v.createError("invalid type for parameter "+strconv.Itoa(i)))
|
|
}
|
|
}
|
|
}
|
|
|
|
// TODO: get function and validate using return type
|
|
expr.ValueType = calledFunc.ReturnType
|
|
expr.Value = fc
|
|
case Expression_Negate:
|
|
neg := expr.Value.(NegateExpression)
|
|
|
|
errors = append(errors, v.validateExpression(&neg.Value, block)...)
|
|
|
|
if neg.Value.ValueType.Type != Type_Primitive {
|
|
errors = append(errors, v.createError("cannot negate non-number types"))
|
|
}
|
|
|
|
expr.ValueType = neg.Value.ValueType
|
|
expr.Value = neg
|
|
}
|
|
|
|
return errors
|
|
}
|
|
|
|
func (v *Validator) validateStatement(stmt *Statement, block *Block, functionLocals *[]Local) []error {
|
|
var errors []error
|
|
|
|
// TODO: support references to variables in parent block
|
|
|
|
switch stmt.Type {
|
|
case Statement_Expression:
|
|
expression := stmt.Value.(ExpressionStatement)
|
|
errors = append(errors, v.validateExpression(&expression.Expression, block)...)
|
|
*stmt = Statement{Type: Statement_Expression, Value: expression}
|
|
case Statement_Block:
|
|
block := stmt.Value.(BlockStatement)
|
|
errors = append(errors, v.validateBlock(&block.Block, functionLocals)...)
|
|
*stmt = Statement{Type: Statement_Block, Value: block}
|
|
case Statement_Return:
|
|
ret := stmt.Value.(ReturnStatement)
|
|
if ret.Value != nil {
|
|
errors = append(errors, v.validateExpression(ret.Value, block)...)
|
|
}
|
|
case Statement_DeclareLocalVariable:
|
|
dlv := stmt.Value.(DeclareLocalVariableStatement)
|
|
if dlv.Initializer != nil {
|
|
errors = append(errors, v.validateExpression(dlv.Initializer, block)...)
|
|
}
|
|
|
|
if _, ok := block.Locals[dlv.Variable]; ok {
|
|
errors = append(errors, v.createError("redeclaration of variable '"+dlv.Variable+"'"))
|
|
}
|
|
|
|
local := Local{Name: dlv.Variable, Type: dlv.VariableType, IsParameter: false, Index: len(*functionLocals)}
|
|
block.Locals[dlv.Variable] = local
|
|
*functionLocals = append(*functionLocals, local)
|
|
|
|
// TODO: check if assignment of initializer is correct
|
|
}
|
|
|
|
return errors
|
|
}
|
|
|
|
func (v *Validator) validateBlock(block *Block, functionLocals *[]Local) []error {
|
|
var errors []error
|
|
|
|
if block.Locals == nil {
|
|
block.Locals = make(map[string]Local)
|
|
}
|
|
|
|
for i := range block.Statements {
|
|
stmt := &block.Statements[i]
|
|
errors = append(errors, v.validateStatement(stmt, block, functionLocals)...)
|
|
}
|
|
|
|
return errors
|
|
}
|
|
|
|
func (v *Validator) validateFunction(function *ParsedFunction) []error {
|
|
var errors []error
|
|
|
|
var locals []Local
|
|
|
|
body := &function.Body
|
|
body.Locals = make(map[string]Local)
|
|
for _, param := range function.Parameters {
|
|
local := Local{Name: param.Name, Type: param.Type, IsParameter: true, Index: len(locals)}
|
|
locals = append(locals, local)
|
|
body.Locals[param.Name] = local
|
|
}
|
|
|
|
errors = append(errors, v.validateBlock(body, &locals)...)
|
|
|
|
function.Locals = locals
|
|
|
|
return errors
|
|
}
|
|
|
|
func (v *Validator) validate() []error {
|
|
var errors []error
|
|
|
|
for i := range v.file.Imports {
|
|
errors = append(errors, v.validateImport(&v.file.Imports[i])...)
|
|
}
|
|
|
|
for i := range v.file.Functions {
|
|
errors = append(errors, v.validateFunction(&v.file.Functions[i])...)
|
|
}
|
|
|
|
return errors
|
|
}
|