2024-03-16 20:12:00 +01:00
package main
import (
2024-03-28 15:48:50 +01:00
"fmt"
2024-03-20 19:26:48 +01:00
"strconv"
2024-03-16 20:12:00 +01:00
)
2024-03-20 19:26:48 +01:00
type Validator struct {
2024-03-29 15:12:37 +01:00
files [ ] * ParsedFile
allFunctions map [ string ] * ParsedFunction
2024-03-24 15:19:45 +01:00
currentBlock * Block
currentFunction * ParsedFunction
2024-03-16 20:12:00 +01:00
}
2024-03-20 19:26:48 +01:00
func isTypeExpandableTo ( from Type , to Type ) bool {
2024-03-24 15:19:45 +01:00
if from . Type != to . Type {
// cannot convert between primitive, named, array and tuple types
return false
}
if from . Type == Type_Primitive {
2024-03-20 19:26:48 +01:00
return isPrimitiveTypeExpandableTo ( from . Value . ( PrimitiveType ) , to . Value . ( PrimitiveType ) )
}
2024-03-24 15:19:45 +01:00
if from . Type == Type_Tuple {
fromT := from . Value . ( TupleType )
toT := to . Value . ( TupleType )
if len ( fromT . Types ) != len ( toT . Types ) {
return false
}
for i := 0 ; i < len ( fromT . Types ) ; i ++ {
if ! isTypeExpandableTo ( fromT . Types [ i ] , toT . Types [ i ] ) {
return false
}
}
return true
}
2024-03-20 19:26:48 +01:00
panic ( "not implemented" )
2024-03-16 20:12:00 +01:00
}
2024-03-20 19:26:48 +01:00
func isPrimitiveTypeExpandableTo ( from PrimitiveType , to PrimitiveType ) bool {
2024-03-17 19:55:28 +01:00
if from == to {
return true
}
switch from {
2024-03-19 12:19:19 +01:00
case Primitive_I8 , Primitive_U8 :
2024-03-17 19:55:28 +01:00
if to == Primitive_I16 || to == Primitive_U16 {
return true
}
fallthrough
2024-03-19 12:19:19 +01:00
case Primitive_I16 , Primitive_U16 :
2024-03-17 19:55:28 +01:00
if to == Primitive_I32 || to == Primitive_U32 {
return true
}
fallthrough
2024-03-19 12:19:19 +01:00
case Primitive_I32 , Primitive_U32 :
2024-03-17 19:55:28 +01:00
if to == Primitive_I64 || to == Primitive_U64 {
return true
}
case Primitive_F32 :
if to == Primitive_F64 {
return true
}
}
return false
}
2024-03-29 15:12:37 +01:00
func ( v * Validator ) createError ( message string , position TokenPosition ) error {
2024-03-21 19:55:05 +01:00
return CompilerError { Position : position , Message : message }
2024-03-20 19:26:48 +01:00
}
func ( v * Validator ) validateImport ( imp * Import ) [ ] error {
2024-03-28 15:48:50 +01:00
// TODO imports
2024-03-20 19:26:48 +01:00
return nil
}
2024-03-24 14:01:23 +01:00
func ( v * Validator ) getArithmeticResultType ( expr * Expression , left PrimitiveType , right PrimitiveType , operation Operation ) ( PrimitiveType , error ) {
2024-03-17 19:55:28 +01:00
if left == Primitive_Bool || right == Primitive_Bool {
2024-03-21 19:55:05 +01:00
return InvalidValue , v . createError ( "bool type cannot be used in arithmetic expressions" , expr . Position )
2024-03-17 19:55:28 +01:00
}
2024-03-20 19:26:48 +01:00
if isPrimitiveTypeExpandableTo ( left , right ) {
2024-03-17 19:55:28 +01:00
return right , nil
}
2024-03-20 19:26:48 +01:00
if isPrimitiveTypeExpandableTo ( right , left ) {
2024-03-17 19:55:28 +01:00
return left , nil
}
2024-03-28 15:48:50 +01:00
return InvalidValue , v . createError ( fmt . Sprintf ( "cannot use the types [%s, %s] in an arithmetic expression without an explicit cast" , left , right ) , expr . Position ) // TODO: include type names in error
2024-03-17 19:55:28 +01:00
}
2024-03-24 20:53:34 +01:00
// getLocal looks up variable in block and then in each enclosing block,
// innermost first. It returns a pointer to a copy of the matching Local, or
// nil when no block in the chain declares the variable.
func getLocal(block *Block, variable string) *Local {
	for scope := block; ; scope = scope.Parent {
		if found, ok := scope.Locals[variable]; ok {
			return &found
		}
		if scope.Parent == nil {
			return nil
		}
	}
}
2024-03-24 15:19:45 +01:00
func ( v * Validator ) validatePotentiallyVoidExpression ( expr * Expression ) [ ] error {
2024-03-16 20:12:00 +01:00
var errors [ ] error
switch expr . Type {
case Expression_Assignment :
assignment := expr . Value . ( AssignmentExpression )
2024-03-24 20:53:34 +01:00
local := getLocal ( v . currentBlock , assignment . Variable )
if local == nil {
2024-03-21 19:55:05 +01:00
errors = append ( errors , v . createError ( "Assignment to undeclared variable " + assignment . Variable , expr . Position ) )
2024-03-16 20:12:00 +01:00
return errors
}
2024-03-24 15:19:45 +01:00
valueErrors := v . validateExpression ( & assignment . Value )
2024-03-16 20:12:00 +01:00
if len ( valueErrors ) != 0 {
errors = append ( errors , valueErrors ... )
return errors
}
2024-03-24 21:36:34 +01:00
if ! isTypeExpandableTo ( * assignment . Value . ValueType , local . Type ) {
2024-03-28 15:48:50 +01:00
errors = append ( errors , v . createError ( fmt . Sprintf ( "cannot assign expression of type %s to variable of type %s" , * assignment . Value . ValueType , local . Type ) , expr . Position ) )
2024-03-24 21:36:34 +01:00
}
2024-03-21 19:55:05 +01:00
expr . ValueType = & local . Type
2024-03-19 10:54:21 +01:00
expr . Value = assignment
2024-03-16 20:12:00 +01:00
case Expression_Literal :
literal := expr . Value . ( LiteralExpression )
switch literal . Literal . Type {
2024-03-19 12:48:06 +01:00
case Literal_Boolean , Literal_Number :
2024-03-21 19:55:05 +01:00
expr . ValueType = & Type { Type : Type_Primitive , Value : literal . Literal . Primitive }
2024-03-16 20:12:00 +01:00
case Literal_String :
2024-03-21 19:55:05 +01:00
expr . ValueType = & STRING_TYPE
2024-03-16 20:12:00 +01:00
}
case Expression_VariableReference :
reference := expr . Value . ( VariableReferenceExpression )
2024-03-24 20:53:34 +01:00
local := getLocal ( v . currentBlock , reference . Variable )
if local == nil {
2024-03-21 19:55:05 +01:00
errors = append ( errors , v . createError ( "Reference to undeclared variable " + reference . Variable , expr . Position ) )
2024-03-16 20:12:00 +01:00
return errors
}
2024-03-21 19:55:05 +01:00
expr . ValueType = & local . Type
2024-03-23 14:03:20 +01:00
case Expression_Binary :
2024-03-24 14:01:23 +01:00
binary := expr . Value . ( BinaryExpression )
2024-03-16 20:12:00 +01:00
2024-03-24 15:19:45 +01:00
errors = append ( errors , v . validateExpression ( & binary . Left ) ... )
errors = append ( errors , v . validateExpression ( & binary . Right ) ... )
2024-03-16 20:12:00 +01:00
if len ( errors ) != 0 {
return errors
}
2024-03-24 14:01:23 +01:00
if isBooleanOperation ( binary . Operation ) {
if binary . Left . ValueType . Type != Type_Primitive || binary . Right . ValueType . Type != Type_Primitive {
errors = append ( errors , v . createError ( "cannot compare non-primitive types" , expr . Position ) )
return errors
}
leftType := binary . Left . ValueType . Value . ( PrimitiveType )
rightType := binary . Right . ValueType . Value . ( PrimitiveType )
var result PrimitiveType = InvalidValue
if isPrimitiveTypeExpandableTo ( leftType , rightType ) {
result = leftType
}
if isPrimitiveTypeExpandableTo ( rightType , leftType ) {
result = leftType
}
if result == InvalidValue {
2024-03-28 15:48:50 +01:00
errors = append ( errors , v . createError ( fmt . Sprintf ( "cannot compare the types %s and %s without an explicit cast" , leftType , rightType ) , expr . Position ) )
2024-03-24 14:01:23 +01:00
return errors
}
binary . ResultType = & Type { Type : Type_Primitive , Value : result }
expr . ValueType = & Type { Type : Type_Primitive , Value : Primitive_Bool }
2024-03-17 19:55:28 +01:00
}
2024-03-24 14:01:23 +01:00
if isArithmeticOperation ( binary . Operation ) {
if binary . Left . ValueType . Type != Type_Primitive || binary . Right . ValueType . Type != Type_Primitive {
2024-03-28 15:48:50 +01:00
errors = append ( errors , v . createError ( "both sides of an arithmetic expression must evaluate to a primitive type" , expr . Position ) )
2024-03-24 14:01:23 +01:00
return errors
}
leftType := binary . Left . ValueType . Value . ( PrimitiveType )
rightType := binary . Right . ValueType . Value . ( PrimitiveType )
result , err := v . getArithmeticResultType ( expr , leftType , rightType , binary . Operation )
if err != nil {
errors = append ( errors , err )
return errors
}
binary . ResultType = & Type { Type : Type_Primitive , Value : result }
expr . ValueType = & Type { Type : Type_Primitive , Value : result }
2024-03-17 19:55:28 +01:00
}
2024-03-24 14:01:23 +01:00
expr . Value = binary
2024-03-16 20:12:00 +01:00
case Expression_Tuple :
tuple := expr . Value . ( TupleExpression )
var types [ ] Type
2024-03-17 19:55:28 +01:00
for i := range tuple . Members {
member := & tuple . Members [ i ]
2024-03-24 15:19:45 +01:00
memberErrors := v . validateExpression ( member )
2024-03-16 20:12:00 +01:00
if len ( memberErrors ) != 0 {
errors = append ( errors , memberErrors ... )
continue
}
2024-03-21 19:55:05 +01:00
types = append ( types , * member . ValueType )
2024-03-16 20:12:00 +01:00
}
if len ( errors ) != 0 {
return errors
}
2024-03-21 19:55:05 +01:00
expr . ValueType = & Type { Type : Type_Tuple , Value : TupleType { Types : types } }
2024-03-19 10:54:21 +01:00
expr . Value = tuple
2024-03-20 19:26:48 +01:00
case Expression_FunctionCall :
fc := expr . Value . ( FunctionCallExpression )
2024-03-29 15:12:37 +01:00
calledFunc , ok := v . allFunctions [ fc . Function ]
if ! ok {
2024-03-21 19:55:05 +01:00
errors = append ( errors , v . createError ( "call to undefined function '" + fc . Function + "'" , expr . Position ) )
2024-03-20 19:26:48 +01:00
return errors
}
if fc . Parameters != nil {
2024-03-24 15:19:45 +01:00
paramsErrors := v . validateExpression ( fc . Parameters )
2024-03-21 20:37:21 +01:00
if len ( paramsErrors ) != 0 {
2024-03-21 19:55:05 +01:00
errors = append ( errors , paramsErrors ... )
return errors
}
2024-03-20 19:26:48 +01:00
params := [ ] Expression { * fc . Parameters }
if fc . Parameters . Type == Expression_Tuple {
params = fc . Parameters . Value . ( TupleExpression ) . Members
}
if len ( params ) != len ( calledFunc . Parameters ) {
2024-03-21 19:55:05 +01:00
errors = append ( errors , v . createError ( "wrong number of arguments in function call: expected " + strconv . Itoa ( len ( calledFunc . Parameters ) ) + ", got " + strconv . Itoa ( len ( params ) ) , expr . Position ) )
2024-03-20 19:26:48 +01:00
}
for i := 0 ; i < min ( len ( params ) , len ( calledFunc . Parameters ) ) ; i ++ {
typeGiven := params [ i ]
typeExpected := calledFunc . Parameters [ i ]
2024-03-21 19:55:05 +01:00
if ! isTypeExpandableTo ( * typeGiven . ValueType , typeExpected . Type ) {
errors = append ( errors , v . createError ( "invalid type for parameter " + strconv . Itoa ( i ) , expr . Position ) )
2024-03-20 19:26:48 +01:00
}
}
}
// TODO: get function and validate using return type
expr . ValueType = calledFunc . ReturnType
expr . Value = fc
case Expression_Negate :
neg := expr . Value . ( NegateExpression )
2024-03-24 15:19:45 +01:00
valErrors := v . validateExpression ( & neg . Value )
2024-03-21 20:37:21 +01:00
if len ( valErrors ) != 0 {
2024-03-21 19:55:05 +01:00
errors = append ( errors , valErrors ... )
return errors
}
2024-03-20 19:26:48 +01:00
if neg . Value . ValueType . Type != Type_Primitive {
2024-03-21 19:55:05 +01:00
errors = append ( errors , v . createError ( "cannot negate non-number types" , expr . Position ) )
2024-03-20 19:26:48 +01:00
}
expr . ValueType = neg . Value . ValueType
expr . Value = neg
2024-03-24 14:01:23 +01:00
default :
panic ( "expr not implemented" )
2024-03-16 20:12:00 +01:00
}
return errors
}
2024-03-24 15:19:45 +01:00
func ( v * Validator ) validateExpression ( expr * Expression ) [ ] error {
errors := v . validatePotentiallyVoidExpression ( expr )
2024-03-21 20:37:21 +01:00
if len ( errors ) != 0 {
return errors
}
2024-03-21 19:55:05 +01:00
if expr . ValueType == nil {
errors = append ( errors , v . createError ( "expression must not evaluate to void" , expr . Position ) )
}
return errors
}
2024-03-24 15:19:45 +01:00
func ( v * Validator ) validateStatement ( stmt * Statement , functionLocals * [ ] Local ) [ ] error {
2024-03-16 20:12:00 +01:00
var errors [ ] error
switch stmt . Type {
case Statement_Expression :
expression := stmt . Value . ( ExpressionStatement )
2024-03-24 15:19:45 +01:00
errors = append ( errors , v . validatePotentiallyVoidExpression ( & expression . Expression ) ... )
2024-03-21 19:55:05 +01:00
stmt . Value = expression
2024-03-16 20:12:00 +01:00
case Statement_Block :
block := stmt . Value . ( BlockStatement )
2024-03-24 20:53:34 +01:00
errors = append ( errors , v . validateBlock ( block . Block , functionLocals ) ... )
2024-03-21 19:55:05 +01:00
stmt . Value = block
2024-03-16 20:12:00 +01:00
case Statement_Return :
ret := stmt . Value . ( ReturnStatement )
if ret . Value != nil {
2024-03-24 15:19:45 +01:00
if v . currentFunction . ReturnType == nil {
errors = append ( errors , v . createError ( "cannot return value from void function" , stmt . Position ) )
return errors
}
errors = append ( errors , v . validateExpression ( ret . Value ) ... )
if len ( errors ) != 0 {
return errors
}
if ! isTypeExpandableTo ( * ret . Value . ValueType , * v . currentFunction . ReturnType ) {
2024-03-28 18:20:52 +01:00
errors = append ( errors , v . createError ( fmt . Sprintf ( "cannot return value of type %s from function returning %s" , * ret . Value . ValueType , * v . currentFunction . ReturnType ) , ret . Value . Position ) )
2024-03-24 15:19:45 +01:00
}
} else if v . currentFunction . ReturnType != nil {
errors = append ( errors , v . createError ( "missing return value" , stmt . Position ) )
2024-03-16 20:12:00 +01:00
}
2024-03-24 20:53:34 +01:00
stmt . Value = ret
2024-03-16 20:12:00 +01:00
case Statement_DeclareLocalVariable :
dlv := stmt . Value . ( DeclareLocalVariableStatement )
if dlv . Initializer != nil {
2024-03-24 15:19:45 +01:00
errors = append ( errors , v . validateExpression ( dlv . Initializer ) ... )
2024-03-24 21:36:34 +01:00
if errors != nil {
return errors
}
if ! isTypeExpandableTo ( * dlv . Initializer . ValueType , dlv . VariableType ) {
2024-03-28 15:48:50 +01:00
errors = append ( errors , v . createError ( fmt . Sprintf ( "cannot assign expression of type %s to variable of type %s" , * dlv . Initializer . ValueType , dlv . VariableType ) , stmt . Position ) )
2024-03-24 21:36:34 +01:00
}
2024-03-16 20:12:00 +01:00
}
2024-03-24 21:36:34 +01:00
if getLocal ( v . currentBlock , dlv . Variable ) != nil {
2024-03-21 19:55:05 +01:00
errors = append ( errors , v . createError ( "redeclaration of variable '" + dlv . Variable + "'" , stmt . Position ) )
2024-03-17 19:55:28 +01:00
}
2024-03-18 21:14:28 +01:00
local := Local { Name : dlv . Variable , Type : dlv . VariableType , IsParameter : false , Index : len ( * functionLocals ) }
2024-03-24 15:19:45 +01:00
v . currentBlock . Locals [ dlv . Variable ] = local
2024-03-18 21:14:28 +01:00
* functionLocals = append ( * functionLocals , local )
2024-03-19 12:48:06 +01:00
2024-03-24 14:01:23 +01:00
stmt . Value = dlv
case Statement_If :
ifS := stmt . Value . ( IfStatement )
2024-03-24 15:19:45 +01:00
errors = append ( errors , v . validateExpression ( & ifS . Condition ) ... )
2024-03-24 20:53:34 +01:00
errors = append ( errors , v . validateBlock ( ifS . ConditionalBlock , functionLocals ) ... )
2024-03-24 14:01:23 +01:00
if ifS . ElseBlock != nil {
errors = append ( errors , v . validateBlock ( ifS . ElseBlock , functionLocals ) ... )
}
if len ( errors ) != 0 {
return errors
}
if ifS . Condition . ValueType . Type != Type_Primitive || ifS . Condition . ValueType . Value . ( PrimitiveType ) != Primitive_Bool {
errors = append ( errors , v . createError ( "condition must evaluate to boolean" , ifS . Condition . Position ) )
}
stmt . Value = ifS
default :
panic ( "stmt not implemented" )
2024-03-16 20:12:00 +01:00
}
return errors
}
2024-03-20 19:26:48 +01:00
func ( v * Validator ) validateBlock ( block * Block , functionLocals * [ ] Local ) [ ] error {
2024-03-16 20:12:00 +01:00
var errors [ ] error
2024-03-17 19:55:28 +01:00
if block . Locals == nil {
block . Locals = make ( map [ string ] Local )
}
2024-03-16 20:12:00 +01:00
2024-03-17 19:55:28 +01:00
for i := range block . Statements {
2024-03-24 15:19:45 +01:00
v . currentBlock = block
2024-03-17 19:55:28 +01:00
stmt := & block . Statements [ i ]
2024-03-24 15:19:45 +01:00
errors = append ( errors , v . validateStatement ( stmt , functionLocals ) ... )
2024-03-16 20:12:00 +01:00
}
return errors
}
2024-03-20 19:26:48 +01:00
func ( v * Validator ) validateFunction ( function * ParsedFunction ) [ ] error {
2024-03-16 20:12:00 +01:00
var errors [ ] error
2024-03-18 21:14:28 +01:00
var locals [ ] Local
2024-03-24 15:19:45 +01:00
v . currentFunction = function
2024-03-24 20:53:34 +01:00
body := function . Body
2024-03-17 19:55:28 +01:00
body . Locals = make ( map [ string ] Local )
for _ , param := range function . Parameters {
2024-03-18 21:14:28 +01:00
local := Local { Name : param . Name , Type : param . Type , IsParameter : true , Index : len ( locals ) }
locals = append ( locals , local )
body . Locals [ param . Name ] = local
2024-03-17 19:55:28 +01:00
}
2024-03-20 19:26:48 +01:00
errors = append ( errors , v . validateBlock ( body , & locals ) ... )
2024-03-17 19:55:28 +01:00
2024-03-24 21:36:34 +01:00
// TODO: validate that function returns return value
2024-03-18 21:14:28 +01:00
function . Locals = locals
2024-03-16 20:12:00 +01:00
return errors
}
2024-03-20 19:26:48 +01:00
func ( v * Validator ) validate ( ) [ ] error {
2024-03-16 20:12:00 +01:00
var errors [ ] error
2024-03-29 15:12:37 +01:00
v . allFunctions = make ( map [ string ] * ParsedFunction )
for _ , file := range v . files {
for i := range file . Functions {
function := & file . Functions [ i ]
fullFunctionName := function . Name
if file . Module != "" {
fullFunctionName = file . Module + "." + fullFunctionName
}
function . FullName = fullFunctionName
if _ , exists := v . allFunctions [ fullFunctionName ] ; exists {
errors = append ( errors , v . createError ( "duplicate function " + fullFunctionName , function . ReturnType . Position ) )
}
v . allFunctions [ fullFunctionName ] = function
}
2024-03-16 20:12:00 +01:00
}
2024-03-29 15:12:37 +01:00
for _ , file := range v . files {
for i := range file . Imports {
errors = append ( errors , v . validateImport ( & file . Imports [ i ] ) ... )
}
for i := range file . Functions {
errors = append ( errors , v . validateFunction ( & file . Functions [ i ] ) ... )
}
2024-03-16 20:12:00 +01:00
}
return errors
}