238 lines
5.8 KiB
Go
238 lines
5.8 KiB
Go
package main
|
|
|
|
import (
|
|
"errors"
|
|
"log"
|
|
)
|
|
|
|
// createError builds a validation error carrying the given message.
//
// The message is used verbatim as the error text; no source position is
// attached yet.
func createError(message string) error {
	// TODO: pass token and get actual token position
	err := errors.New(message)
	return err
}
|
|
|
|
func validateImport(imp *Import) []error {
|
|
// TODO
|
|
return nil
|
|
}
|
|
|
|
func isTypeExpandableTo(from PrimitiveType, to PrimitiveType) bool {
|
|
if from == to {
|
|
return true
|
|
}
|
|
|
|
switch from {
|
|
case Primitive_I8:
|
|
case Primitive_U8:
|
|
if to == Primitive_I16 || to == Primitive_U16 {
|
|
return true
|
|
}
|
|
|
|
fallthrough
|
|
case Primitive_I16:
|
|
case Primitive_U16:
|
|
if to == Primitive_I32 || to == Primitive_U32 {
|
|
return true
|
|
}
|
|
|
|
fallthrough
|
|
case Primitive_I32:
|
|
case Primitive_U32:
|
|
if to == Primitive_I64 || to == Primitive_U64 {
|
|
return true
|
|
}
|
|
|
|
case Primitive_F32:
|
|
if to == Primitive_F64 {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func getArithmeticResultType(left PrimitiveType, right PrimitiveType, operation ArithmeticOperation) (PrimitiveType, error) {
|
|
if left == Primitive_Bool || right == Primitive_Bool {
|
|
return InvalidValue, createError("bool type cannot be used in arithmetic expressions")
|
|
}
|
|
|
|
if isTypeExpandableTo(left, right) {
|
|
return right, nil
|
|
}
|
|
|
|
if isTypeExpandableTo(right, left) {
|
|
return left, nil
|
|
}
|
|
|
|
// TODO: boolean expressions etc.
|
|
|
|
return InvalidValue, createError("cannot use these types in an arithmetic expression without an explicit cast") // TODO: include type names in error
|
|
}
|
|
|
|
func validateExpression(expr *Expression, block *Block) []error {
|
|
var errors []error
|
|
|
|
switch expr.Type {
|
|
case Expression_Assignment:
|
|
assignment := expr.Value.(AssignmentExpression)
|
|
var local Local
|
|
var ok bool
|
|
if local, ok = block.Locals[assignment.Variable]; !ok {
|
|
errors = append(errors, createError("Assignment to undeclared variable "+assignment.Variable))
|
|
return errors
|
|
}
|
|
|
|
valueErrors := validateExpression(&assignment.Value, block)
|
|
if len(valueErrors) != 0 {
|
|
errors = append(errors, valueErrors...)
|
|
return errors
|
|
}
|
|
|
|
// TODO: check if assignment is valid
|
|
expr.ValueType = local.Type
|
|
case Expression_Literal:
|
|
literal := expr.Value.(LiteralExpression)
|
|
|
|
switch literal.Literal.Type {
|
|
case Literal_Boolean:
|
|
case Literal_Number:
|
|
expr.ValueType = Type{Type: Type_Primitive, Value: literal.Literal.Primitive}
|
|
case Literal_String:
|
|
expr.ValueType = STRING_TYPE
|
|
}
|
|
case Expression_VariableReference:
|
|
reference := expr.Value.(VariableReferenceExpression)
|
|
var local Local
|
|
var ok bool
|
|
if local, ok = block.Locals[reference.Variable]; !ok {
|
|
errors = append(errors, createError("Reference to undeclared variable "+reference.Variable))
|
|
return errors
|
|
}
|
|
|
|
expr.ValueType = local.Type
|
|
case Expression_Arithmetic:
|
|
arithmethic := expr.Value.(ArithmeticExpression)
|
|
|
|
errors = append(errors, validateExpression(&arithmethic.Left, block)...)
|
|
errors = append(errors, validateExpression(&arithmethic.Right, block)...)
|
|
|
|
if len(errors) != 0 {
|
|
return errors
|
|
}
|
|
|
|
// TODO: validate types compatible and determine result type
|
|
if arithmethic.Left.ValueType.Type != Type_Primitive || arithmethic.Right.ValueType.Type != Type_Primitive {
|
|
errors = append(errors, createError("both sides of an arithmetic expression must be a primitive type"))
|
|
return errors
|
|
}
|
|
|
|
leftType := arithmethic.Left.ValueType.Value.(PrimitiveType)
|
|
rightType := arithmethic.Left.ValueType.Value.(PrimitiveType)
|
|
result, err := getArithmeticResultType(leftType, rightType, arithmethic.Operation)
|
|
if err != nil {
|
|
errors = append(errors, err)
|
|
return errors
|
|
}
|
|
|
|
expr.ValueType = Type{Type: Type_Primitive, Value: result}
|
|
case Expression_Tuple:
|
|
tuple := expr.Value.(TupleExpression)
|
|
|
|
var types []Type
|
|
for i := range tuple.Members {
|
|
member := &tuple.Members[i]
|
|
|
|
memberErrors := validateExpression(member, block)
|
|
if len(memberErrors) != 0 {
|
|
errors = append(errors, memberErrors...)
|
|
continue
|
|
}
|
|
|
|
types = append(types, member.ValueType)
|
|
}
|
|
|
|
if len(errors) != 0 {
|
|
return errors
|
|
}
|
|
|
|
expr.ValueType = Type{Type: Type_Tuple, Value: TupleType{Types: types}}
|
|
}
|
|
|
|
return errors
|
|
}
|
|
|
|
func validateStatement(stmt *Statement, block *Block) []error {
|
|
var errors []error
|
|
|
|
switch stmt.Type {
|
|
case Statement_Expression:
|
|
expression := stmt.Value.(ExpressionStatement)
|
|
errors = append(errors, validateExpression(&expression.Expression, block)...)
|
|
case Statement_Block:
|
|
block := stmt.Value.(BlockStatement)
|
|
errors = append(errors, validateBlock(&block.Block)...)
|
|
case Statement_Return:
|
|
ret := stmt.Value.(ReturnStatement)
|
|
if ret.Value != nil {
|
|
errors = append(errors, validateExpression(ret.Value, block)...)
|
|
}
|
|
case Statement_DeclareLocalVariable:
|
|
dlv := stmt.Value.(DeclareLocalVariableStatement)
|
|
if dlv.Initializer != nil {
|
|
errors = append(errors, validateExpression(dlv.Initializer, block)...)
|
|
}
|
|
|
|
if _, ok := block.Locals[dlv.Variable]; ok {
|
|
errors = append(errors, createError("redeclaration of variable '"+dlv.Variable+"'"))
|
|
}
|
|
|
|
block.Locals[dlv.Variable] = Local{Name: dlv.Variable, Type: dlv.VariableType}
|
|
}
|
|
|
|
return errors
|
|
}
|
|
|
|
func validateBlock(block *Block) []error {
|
|
var errors []error
|
|
|
|
if block.Locals == nil {
|
|
block.Locals = make(map[string]Local)
|
|
}
|
|
|
|
for i := range block.Statements {
|
|
stmt := &block.Statements[i]
|
|
errors = append(errors, validateStatement(stmt, block)...)
|
|
}
|
|
|
|
return errors
|
|
}
|
|
|
|
func validateFunction(function *ParsedFunction) []error {
|
|
var errors []error
|
|
|
|
body := &function.Body
|
|
body.Locals = make(map[string]Local)
|
|
for _, param := range function.Parameters {
|
|
body.Locals[param.Name] = Local(param)
|
|
}
|
|
|
|
errors = append(errors, validateBlock(body)...)
|
|
|
|
log.Printf("%+#v", body)
|
|
|
|
return errors
|
|
}
|
|
|
|
func validator(file *ParsedFile) []error {
|
|
var errors []error
|
|
|
|
for i := range file.Imports {
|
|
errors = append(errors, validateImport(&file.Imports[i])...)
|
|
}
|
|
|
|
for i := range file.Functions {
|
|
errors = append(errors, validateFunction(&file.Functions[i])...)
|
|
}
|
|
|
|
return errors
|
|
}
|