From ad32195fa2219ae96c83e530c02db61c98e6f840 Mon Sep 17 00:00:00 2001 From: MrLetsplay Date: Sun, 17 Mar 2024 19:55:28 +0100 Subject: [PATCH] More validation --- example/add.lang | 4 ++ lexer.go | 2 +- main.go | 2 + parser.go | 98 ++++++++++++++++++++++++++++++++++++++++--- types.go | 17 +++++++- validator.go | 105 ++++++++++++++++++++++++++++++++++++++++++----- 6 files changed, 210 insertions(+), 18 deletions(-) create mode 100644 example/add.lang diff --git a/example/add.lang b/example/add.lang new file mode 100644 index 0000000..49d67c9 --- /dev/null +++ b/example/add.lang @@ -0,0 +1,4 @@ +u8 add(u8 a, u8 b) { + u8 c = b; + return a + b; +} diff --git a/lexer.go b/lexer.go index 22470e7..9f91f74 100644 --- a/lexer.go +++ b/lexer.go @@ -218,7 +218,7 @@ func (l *Lexer) parseToken(token string) (*LexToken, error) { var numberType PrimitiveType = InvalidValue var rawNumber string = token - for i, name := range NumberTypeNames { + for i, name := range PRIMITIVE_TYPE_NAMES { if strings.HasSuffix(token, name) { numberType = PrimitiveType(i) rawNumber = token[:len(token)-len(name)] diff --git a/main.go b/main.go index f9f4505..634eb28 100644 --- a/main.go +++ b/main.go @@ -49,4 +49,6 @@ func main() { log.Fatalln(err) } } + + log.Printf("Validated:\n%+#v\n\n", parsed) } diff --git a/parser.go b/parser.go index 8f8c46e..95b146d 100644 --- a/parser.go +++ b/parser.go @@ -20,10 +20,6 @@ type Type struct { Value any } -const STRING_TYPE_NAME = "string" - -var STRING_TYPE = Type{Type: Type_Named, Value: STRING_TYPE_NAME} - type NamedType struct { TypeName string } @@ -211,6 +207,19 @@ func (p *Parser) expectSeparator(separators ...Separator) (Separator, error) { return *sep, nil } +func (p *Parser) tryOperator(operators ...Operator) (*Operator, error) { + pCopy := p.copy() + + operator := pCopy.nextToken() + if operator == nil || operator.Type != Type_Operator || !slices.Contains(operators, operator.Value.(Operator)) { + return nil, nil + } + + *p = pCopy + sep := operator.Value.(Operator) + return &sep, nil +} + func (p *Parser) expectIdentifier() (string, error) { identifier := p.nextToken() if identifier == nil || identifier.Type != Type_Identifier { @@ -256,6 +265,13 @@ func (p *Parser) tryType() (*Type, error) { if tok.Type == Type_Identifier { // TODO: array type + + index := slices.Index(PRIMITIVE_TYPE_NAMES, tok.Value.(string)) + if index != -1 { + *p = pCopy + return &Type{Type: Type_Primitive, Value: PrimitiveType(index)}, nil + } + *p = pCopy return &Type{Type: Type_Named, Value: tok.Value}, nil } @@ -403,11 +419,81 @@ func (p *Parser) tryUnaryExpression() (*Expression, error) { } func (p *Parser) tryMultiplicativeExpression() (*Expression, error) { - return p.tryUnaryExpression() + left, err := p.tryUnaryExpression() + if err != nil { + return nil, err + } + + op, err := p.tryOperator(Operator_Multiply, Operator_Divide, Operator_Modulo) + if err != nil { + return nil, err + } + + if op == nil { + return left, nil + } + + right, err := p.tryUnaryExpression() + if err != nil { + return nil, err + } + + if right == nil { + return nil, p.error("expected expression") + } + + var operation ArithmeticOperation + switch *op { + case Operator_Multiply: + operation = Arithmetic_Mul + case Operator_Divide: + operation = Arithmetic_Div + case Operator_Modulo: + fallthrough + default: + operation = Arithmetic_Mod + } + if *op == Operator_Plus { + operation = Arithmetic_Add + } else { + operation = Arithmetic_Sub + } + + return &Expression{Type: Expression_Arithmetic, Value: ArithmeticExpression{Operation: operation, Left: *left, Right: *right}}, nil } func (p *Parser) tryAdditiveExpression() (*Expression, error) { - return p.tryMultiplicativeExpression() + left, err := p.tryMultiplicativeExpression() + if err != nil { + return nil, err + } + + op, err := p.tryOperator(Operator_Plus, Operator_Minus) + if err != nil { + return nil, err + } + + if op == nil { + return left, nil + } + + right, err := p.tryMultiplicativeExpression() + if err != nil { + return nil, err + } + + if right == nil { + return nil, p.error("expected expression") + } + + var operation ArithmeticOperation + if *op == Operator_Plus { + operation = Arithmetic_Add + } else { + operation = Arithmetic_Sub + } + + return &Expression{Type: Expression_Arithmetic, Value: ArithmeticExpression{Operation: operation, Left: *left, Right: *right}}, nil } func (p *Parser) tryArithmeticExpression() (*Expression, error) { diff --git a/types.go b/types.go index 9a24e0f..2d485b2 100644 --- a/types.go +++ b/types.go @@ -1,6 +1,8 @@ package main import ( + "errors" + "slices" "strconv" ) @@ -16,7 +18,7 @@ type Lang_U64 uint64 type Lang_Bool bool -var NumberTypeNames = [...]string{"i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "f32", "f64", "bool"} +var PRIMITIVE_TYPE_NAMES = []string{"i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "f32", "f64", "bool"} type PrimitiveType uint32 @@ -34,6 +36,10 @@ const ( Primitive_Bool ) +const STRING_TYPE_NAME = "string" + +var STRING_TYPE = Type{Type: Type_Named, Value: STRING_TYPE_NAME} + const InvalidValue = 0xEEEEEE // Magic value type CompilerError struct { @@ -86,3 +92,12 @@ func getBits(primitiveType PrimitiveType) int { panic("Passed an invalid type (" + strconv.FormatUint(uint64(primitiveType), 10) + ") to getBits()") } } + +func getPrimitiveTypeByName(name string) (PrimitiveType, error) { + idx := slices.Index(PRIMITIVE_TYPE_NAMES, name) + if idx == -1 { + return InvalidValue, errors.New("not a primitive type name") + } + + return PrimitiveType(idx), nil +} diff --git a/validator.go b/validator.go index 6bd412d..1c60c17 100644 --- a/validator.go +++ b/validator.go @@ -2,6 +2,7 @@ package main import ( "errors" + "log" ) func createError(message string) error { @@ -14,6 +15,59 @@ func validateImport(imp *Import) []error { return nil } +func isTypeExpandableTo(from PrimitiveType, to PrimitiveType) bool { + if from == to { + return true + } + + switch from { + case Primitive_I8: + case Primitive_U8: + if to == Primitive_I16 || to == Primitive_U16 { + return true + } + + fallthrough + case Primitive_I16: + case Primitive_U16: + if to == Primitive_I32 || to == Primitive_U32 { + return true + } + + fallthrough + case Primitive_I32: + case Primitive_U32: + if to == Primitive_I64 || to == Primitive_U64 { + return true + } + + case Primitive_F32: + if to == Primitive_F64 { + return true + } + } + + return false +} + +func getArithmeticResultType(left PrimitiveType, right PrimitiveType, operation ArithmeticOperation) (PrimitiveType, error) { + if left == Primitive_Bool || right == Primitive_Bool { + return InvalidValue, createError("bool type cannot be used in arithmetic expressions") + } + + if isTypeExpandableTo(left, right) { + return right, nil + } + + if isTypeExpandableTo(right, left) { + return left, nil + } + + // TODO: boolean expressions etc. + + return InvalidValue, createError("cannot use these types in an arithmetic expression without an explicit cast") // TODO: include type names in error +} + func validateExpression(expr *Expression, block *Block) []error { var errors []error @@ -66,12 +120,28 @@ func validateExpression(expr *Expression, block *Block) []error { } // TODO: validate types compatible and determine result type + if arithmethic.Left.ValueType.Type != Type_Primitive || arithmethic.Right.ValueType.Type != Type_Primitive { + errors = append(errors, createError("both sides of an arithmetic expression must be a primitive type")) + return errors + } + + leftType := arithmethic.Left.ValueType.Value.(PrimitiveType) + rightType := arithmethic.Left.ValueType.Value.(PrimitiveType) + result, err := getArithmeticResultType(leftType, rightType, arithmethic.Operation) + if err != nil { + errors = append(errors, err) + return errors + } + + expr.ValueType = Type{Type: Type_Primitive, Value: result} case Expression_Tuple: tuple := expr.Value.(TupleExpression) var types []Type - for _, member := range tuple.Members { - memberErrors := validateExpression(&member, block) + for i := range tuple.Members { + member := &tuple.Members[i] + + memberErrors := validateExpression(member, block) if len(memberErrors) != 0 { errors = append(errors, memberErrors...) continue @@ -111,6 +181,10 @@ func validateStatement(stmt *Statement, block *Block) []error { errors = append(errors, validateExpression(dlv.Initializer, block)...) } + if _, ok := block.Locals[dlv.Variable]; ok { + errors = append(errors, createError("redeclaration of variable '"+dlv.Variable+"'")) + } + block.Locals[dlv.Variable] = Local{Name: dlv.Variable, Type: dlv.VariableType} } @@ -120,10 +194,13 @@ func validateStatement(stmt *Statement, block *Block) []error { func validateBlock(block *Block) []error { var errors []error - block.Locals = make(map[string]Local) + if block.Locals == nil { + block.Locals = make(map[string]Local) + } - for _, stmt := range block.Statements { - errors = append(errors, validateStatement(&stmt, block)...) + for i := range block.Statements { + stmt := &block.Statements[i] + errors = append(errors, validateStatement(stmt, block)...) } return errors @@ -132,7 +209,15 @@ func validateBlock(block *Block) []error { func validateFunction(function *ParsedFunction) []error { var errors []error - errors = append(errors, validateBlock(&function.Body)...) + body := &function.Body + body.Locals = make(map[string]Local) + for _, param := range function.Parameters { + body.Locals[param.Name] = Local(param) + } + + errors = append(errors, validateBlock(body)...) + + log.Printf("%+#v", body) return errors } @@ -140,12 +225,12 @@ func validateFunction(function *ParsedFunction) []error { func validator(file *ParsedFile) []error { var errors []error - for _, imp := range file.Imports { - errors = append(errors, validateImport(&imp)...) + for i := range file.Imports { + errors = append(errors, validateImport(&file.Imports[i])...) } - for _, function := range file.Functions { - errors = append(errors, validateFunction(&function)...) + for i := range file.Functions { + errors = append(errors, validateFunction(&file.Functions[i])...) } return errors