elysium/validator.go

386 lines
11 KiB
Go
Raw Normal View History

2024-03-16 20:12:00 +01:00
package main
import (
	"reflect"
	"strconv"
)
type Validator struct {
file *ParsedFile
2024-03-16 20:12:00 +01:00
}
func isTypeExpandableTo(from Type, to Type) bool {
if from.Type == Type_Primitive && to.Type == Type_Primitive {
return isPrimitiveTypeExpandableTo(from.Value.(PrimitiveType), to.Value.(PrimitiveType))
}
panic("not implemented")
2024-03-16 20:12:00 +01:00
}
func isPrimitiveTypeExpandableTo(from PrimitiveType, to PrimitiveType) bool {
2024-03-17 19:55:28 +01:00
if from == to {
return true
}
switch from {
2024-03-19 12:19:19 +01:00
case Primitive_I8, Primitive_U8:
2024-03-17 19:55:28 +01:00
if to == Primitive_I16 || to == Primitive_U16 {
return true
}
fallthrough
2024-03-19 12:19:19 +01:00
case Primitive_I16, Primitive_U16:
2024-03-17 19:55:28 +01:00
if to == Primitive_I32 || to == Primitive_U32 {
return true
}
fallthrough
2024-03-19 12:19:19 +01:00
case Primitive_I32, Primitive_U32:
2024-03-17 19:55:28 +01:00
if to == Primitive_I64 || to == Primitive_U64 {
return true
}
case Primitive_F32:
if to == Primitive_F64 {
return true
}
}
return false
}
2024-03-21 19:55:05 +01:00
func (v *Validator) createError(message string, position uint64) error {
// TODO: pass token and get actual token position
2024-03-21 19:55:05 +01:00
return CompilerError{Position: position, Message: message}
}
// validateImport is a placeholder: imports are not checked yet, so it never
// reports any error.
func (v *Validator) validateImport(imp *Import) []error {
	// TODO
	var noErrors []error
	return noErrors
}
2024-03-24 14:01:23 +01:00
func (v *Validator) getArithmeticResultType(expr *Expression, left PrimitiveType, right PrimitiveType, operation Operation) (PrimitiveType, error) {
2024-03-17 19:55:28 +01:00
if left == Primitive_Bool || right == Primitive_Bool {
2024-03-21 19:55:05 +01:00
return InvalidValue, v.createError("bool type cannot be used in arithmetic expressions", expr.Position)
2024-03-17 19:55:28 +01:00
}
if isPrimitiveTypeExpandableTo(left, right) {
2024-03-17 19:55:28 +01:00
return right, nil
}
if isPrimitiveTypeExpandableTo(right, left) {
2024-03-17 19:55:28 +01:00
return left, nil
}
// TODO: boolean expressions etc.
2024-03-21 19:55:05 +01:00
return InvalidValue, v.createError("cannot use these types in an arithmetic expression without an explicit cast", expr.Position) // TODO: include type names in error
2024-03-17 19:55:28 +01:00
}
2024-03-21 19:55:05 +01:00
func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *Block) []error {
2024-03-16 20:12:00 +01:00
var errors []error
switch expr.Type {
case Expression_Assignment:
assignment := expr.Value.(AssignmentExpression)
var local Local
var ok bool
if local, ok = block.Locals[assignment.Variable]; !ok {
2024-03-21 19:55:05 +01:00
errors = append(errors, v.createError("Assignment to undeclared variable "+assignment.Variable, expr.Position))
2024-03-16 20:12:00 +01:00
return errors
}
valueErrors := v.validateExpression(&assignment.Value, block)
2024-03-16 20:12:00 +01:00
if len(valueErrors) != 0 {
errors = append(errors, valueErrors...)
return errors
}
// TODO: check if assignment is valid
2024-03-21 19:55:05 +01:00
expr.ValueType = &local.Type
2024-03-19 10:54:21 +01:00
expr.Value = assignment
2024-03-16 20:12:00 +01:00
case Expression_Literal:
literal := expr.Value.(LiteralExpression)
switch literal.Literal.Type {
2024-03-19 12:48:06 +01:00
case Literal_Boolean, Literal_Number:
2024-03-21 19:55:05 +01:00
expr.ValueType = &Type{Type: Type_Primitive, Value: literal.Literal.Primitive}
2024-03-16 20:12:00 +01:00
case Literal_String:
2024-03-21 19:55:05 +01:00
expr.ValueType = &STRING_TYPE
2024-03-16 20:12:00 +01:00
}
case Expression_VariableReference:
reference := expr.Value.(VariableReferenceExpression)
var local Local
var ok bool
if local, ok = block.Locals[reference.Variable]; !ok {
2024-03-21 19:55:05 +01:00
errors = append(errors, v.createError("Reference to undeclared variable "+reference.Variable, expr.Position))
2024-03-16 20:12:00 +01:00
return errors
}
2024-03-21 19:55:05 +01:00
expr.ValueType = &local.Type
2024-03-23 14:03:20 +01:00
case Expression_Binary:
2024-03-24 14:01:23 +01:00
binary := expr.Value.(BinaryExpression)
2024-03-16 20:12:00 +01:00
2024-03-24 14:01:23 +01:00
errors = append(errors, v.validateExpression(&binary.Left, block)...)
errors = append(errors, v.validateExpression(&binary.Right, block)...)
2024-03-16 20:12:00 +01:00
if len(errors) != 0 {
return errors
}
2024-03-24 14:01:23 +01:00
if isBooleanOperation(binary.Operation) {
if binary.Left.ValueType.Type != Type_Primitive || binary.Right.ValueType.Type != Type_Primitive {
errors = append(errors, v.createError("cannot compare non-primitive types", expr.Position))
return errors
}
leftType := binary.Left.ValueType.Value.(PrimitiveType)
rightType := binary.Right.ValueType.Value.(PrimitiveType)
var result PrimitiveType = InvalidValue
if isPrimitiveTypeExpandableTo(leftType, rightType) {
result = leftType
}
if isPrimitiveTypeExpandableTo(rightType, leftType) {
result = leftType
}
if result == InvalidValue {
errors = append(errors, v.createError("cannot compare these types without explicit cast", expr.Position))
return errors
}
binary.ResultType = &Type{Type: Type_Primitive, Value: result}
expr.ValueType = &Type{Type: Type_Primitive, Value: Primitive_Bool}
2024-03-17 19:55:28 +01:00
}
2024-03-24 14:01:23 +01:00
if isArithmeticOperation(binary.Operation) {
if binary.Left.ValueType.Type != Type_Primitive || binary.Right.ValueType.Type != Type_Primitive {
errors = append(errors, v.createError("both sides of an arithmetic expression must be a primitive type", expr.Position))
return errors
}
leftType := binary.Left.ValueType.Value.(PrimitiveType)
rightType := binary.Right.ValueType.Value.(PrimitiveType)
result, err := v.getArithmeticResultType(expr, leftType, rightType, binary.Operation)
if err != nil {
errors = append(errors, err)
return errors
}
binary.ResultType = &Type{Type: Type_Primitive, Value: result}
expr.ValueType = &Type{Type: Type_Primitive, Value: result}
2024-03-17 19:55:28 +01:00
}
2024-03-24 14:01:23 +01:00
expr.Value = binary
2024-03-16 20:12:00 +01:00
case Expression_Tuple:
tuple := expr.Value.(TupleExpression)
var types []Type
2024-03-17 19:55:28 +01:00
for i := range tuple.Members {
member := &tuple.Members[i]
memberErrors := v.validateExpression(member, block)
2024-03-16 20:12:00 +01:00
if len(memberErrors) != 0 {
errors = append(errors, memberErrors...)
continue
}
2024-03-21 19:55:05 +01:00
types = append(types, *member.ValueType)
2024-03-16 20:12:00 +01:00
}
if len(errors) != 0 {
return errors
}
2024-03-21 19:55:05 +01:00
expr.ValueType = &Type{Type: Type_Tuple, Value: TupleType{Types: types}}
2024-03-19 10:54:21 +01:00
expr.Value = tuple
case Expression_FunctionCall:
fc := expr.Value.(FunctionCallExpression)
var calledFunc *ParsedFunction = nil
for _, f := range v.file.Functions {
if f.Name == fc.Function {
calledFunc = &f
break
}
}
if calledFunc == nil {
2024-03-21 19:55:05 +01:00
errors = append(errors, v.createError("call to undefined function '"+fc.Function+"'", expr.Position))
return errors
}
if fc.Parameters != nil {
2024-03-21 19:55:05 +01:00
paramsErrors := v.validateExpression(fc.Parameters, block)
2024-03-21 20:37:21 +01:00
if len(paramsErrors) != 0 {
2024-03-21 19:55:05 +01:00
errors = append(errors, paramsErrors...)
return errors
}
params := []Expression{*fc.Parameters}
if fc.Parameters.Type == Expression_Tuple {
params = fc.Parameters.Value.(TupleExpression).Members
}
if len(params) != len(calledFunc.Parameters) {
2024-03-21 19:55:05 +01:00
errors = append(errors, v.createError("wrong number of arguments in function call: expected "+strconv.Itoa(len(calledFunc.Parameters))+", got "+strconv.Itoa(len(params)), expr.Position))
}
for i := 0; i < min(len(params), len(calledFunc.Parameters)); i++ {
typeGiven := params[i]
typeExpected := calledFunc.Parameters[i]
2024-03-21 19:55:05 +01:00
if !isTypeExpandableTo(*typeGiven.ValueType, typeExpected.Type) {
errors = append(errors, v.createError("invalid type for parameter "+strconv.Itoa(i), expr.Position))
}
}
}
// TODO: get function and validate using return type
expr.ValueType = calledFunc.ReturnType
expr.Value = fc
case Expression_Negate:
neg := expr.Value.(NegateExpression)
2024-03-21 19:55:05 +01:00
valErrors := v.validateExpression(&neg.Value, block)
2024-03-21 20:37:21 +01:00
if len(valErrors) != 0 {
2024-03-21 19:55:05 +01:00
errors = append(errors, valErrors...)
return errors
}
if neg.Value.ValueType.Type != Type_Primitive {
2024-03-21 19:55:05 +01:00
errors = append(errors, v.createError("cannot negate non-number types", expr.Position))
}
expr.ValueType = neg.Value.ValueType
expr.Value = neg
2024-03-24 14:01:23 +01:00
default:
panic("expr not implemented")
2024-03-16 20:12:00 +01:00
}
return errors
}
2024-03-21 19:55:05 +01:00
func (v *Validator) validateExpression(expr *Expression, block *Block) []error {
errors := v.validatePotentiallyVoidExpression(expr, block)
2024-03-21 20:37:21 +01:00
if len(errors) != 0 {
return errors
}
2024-03-21 19:55:05 +01:00
if expr.ValueType == nil {
errors = append(errors, v.createError("expression must not evaluate to void", expr.Position))
}
return errors
}
func (v *Validator) validateStatement(stmt *Statement, block *Block, functionLocals *[]Local) []error {
2024-03-16 20:12:00 +01:00
var errors []error
2024-03-18 21:14:28 +01:00
// TODO: support references to variables in parent block
2024-03-16 20:12:00 +01:00
switch stmt.Type {
case Statement_Expression:
expression := stmt.Value.(ExpressionStatement)
2024-03-21 19:55:05 +01:00
errors = append(errors, v.validatePotentiallyVoidExpression(&expression.Expression, block)...)
stmt.Value = expression
2024-03-16 20:12:00 +01:00
case Statement_Block:
block := stmt.Value.(BlockStatement)
errors = append(errors, v.validateBlock(&block.Block, functionLocals)...)
2024-03-21 19:55:05 +01:00
stmt.Value = block
2024-03-16 20:12:00 +01:00
case Statement_Return:
ret := stmt.Value.(ReturnStatement)
if ret.Value != nil {
errors = append(errors, v.validateExpression(ret.Value, block)...)
2024-03-16 20:12:00 +01:00
}
case Statement_DeclareLocalVariable:
dlv := stmt.Value.(DeclareLocalVariableStatement)
if dlv.Initializer != nil {
errors = append(errors, v.validateExpression(dlv.Initializer, block)...)
2024-03-16 20:12:00 +01:00
}
2024-03-17 19:55:28 +01:00
if _, ok := block.Locals[dlv.Variable]; ok {
2024-03-21 19:55:05 +01:00
errors = append(errors, v.createError("redeclaration of variable '"+dlv.Variable+"'", stmt.Position))
2024-03-17 19:55:28 +01:00
}
2024-03-18 21:14:28 +01:00
local := Local{Name: dlv.Variable, Type: dlv.VariableType, IsParameter: false, Index: len(*functionLocals)}
block.Locals[dlv.Variable] = local
*functionLocals = append(*functionLocals, local)
2024-03-19 12:48:06 +01:00
// TODO: check if assignment of initializer is correct
2024-03-24 14:01:23 +01:00
stmt.Value = dlv
case Statement_If:
ifS := stmt.Value.(IfStatement)
errors = append(errors, v.validateExpression(&ifS.Condition, block)...)
errors = append(errors, v.validateBlock(&ifS.ConditionalBlock, functionLocals)...)
if ifS.ElseBlock != nil {
errors = append(errors, v.validateBlock(ifS.ElseBlock, functionLocals)...)
}
if len(errors) != 0 {
return errors
}
if ifS.Condition.ValueType.Type != Type_Primitive || ifS.Condition.ValueType.Value.(PrimitiveType) != Primitive_Bool {
errors = append(errors, v.createError("condition must evaluate to boolean", ifS.Condition.Position))
}
stmt.Value = ifS
default:
panic("stmt not implemented")
2024-03-16 20:12:00 +01:00
}
return errors
}
func (v *Validator) validateBlock(block *Block, functionLocals *[]Local) []error {
2024-03-16 20:12:00 +01:00
var errors []error
2024-03-17 19:55:28 +01:00
if block.Locals == nil {
block.Locals = make(map[string]Local)
}
2024-03-16 20:12:00 +01:00
2024-03-17 19:55:28 +01:00
for i := range block.Statements {
stmt := &block.Statements[i]
errors = append(errors, v.validateStatement(stmt, block, functionLocals)...)
2024-03-16 20:12:00 +01:00
}
return errors
}
func (v *Validator) validateFunction(function *ParsedFunction) []error {
2024-03-16 20:12:00 +01:00
var errors []error
2024-03-18 21:14:28 +01:00
var locals []Local
2024-03-17 19:55:28 +01:00
body := &function.Body
body.Locals = make(map[string]Local)
for _, param := range function.Parameters {
2024-03-18 21:14:28 +01:00
local := Local{Name: param.Name, Type: param.Type, IsParameter: true, Index: len(locals)}
locals = append(locals, local)
body.Locals[param.Name] = local
2024-03-17 19:55:28 +01:00
}
errors = append(errors, v.validateBlock(body, &locals)...)
2024-03-17 19:55:28 +01:00
2024-03-18 21:14:28 +01:00
function.Locals = locals
2024-03-16 20:12:00 +01:00
return errors
}
func (v *Validator) validate() []error {
2024-03-16 20:12:00 +01:00
var errors []error
for i := range v.file.Imports {
errors = append(errors, v.validateImport(&v.file.Imports[i])...)
2024-03-16 20:12:00 +01:00
}
for i := range v.file.Functions {
errors = append(errors, v.validateFunction(&v.file.Functions[i])...)
2024-03-16 20:12:00 +01:00
}
return errors
}