More validation
This commit is contained in:
parent
a8d5c7d479
commit
ad32195fa2
4
example/add.lang
Normal file
4
example/add.lang
Normal file
@ -0,0 +1,4 @@
|
||||
u8 add(u8 a, u8 b) {
|
||||
u8 c = b;
|
||||
return a + b;
|
||||
}
|
2
lexer.go
2
lexer.go
@ -218,7 +218,7 @@ func (l *Lexer) parseToken(token string) (*LexToken, error) {
|
||||
|
||||
var numberType PrimitiveType = InvalidValue
|
||||
var rawNumber string = token
|
||||
for i, name := range NumberTypeNames {
|
||||
for i, name := range PRIMITIVE_TYPE_NAMES {
|
||||
if strings.HasSuffix(token, name) {
|
||||
numberType = PrimitiveType(i)
|
||||
rawNumber = token[:len(token)-len(name)]
|
||||
|
2
main.go
2
main.go
@ -49,4 +49,6 @@ func main() {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("Validated:\n%+#v\n\n", parsed)
|
||||
}
|
||||
|
98
parser.go
98
parser.go
@ -20,10 +20,6 @@ type Type struct {
|
||||
Value any
|
||||
}
|
||||
|
||||
const STRING_TYPE_NAME = "string"
|
||||
|
||||
var STRING_TYPE = Type{Type: Type_Named, Value: STRING_TYPE_NAME}
|
||||
|
||||
type NamedType struct {
|
||||
TypeName string
|
||||
}
|
||||
@ -211,6 +207,19 @@ func (p *Parser) expectSeparator(separators ...Separator) (Separator, error) {
|
||||
return *sep, nil
|
||||
}
|
||||
|
||||
func (p *Parser) tryOperator(operators ...Operator) (*Operator, error) {
|
||||
pCopy := p.copy()
|
||||
|
||||
operator := pCopy.nextToken()
|
||||
if operator == nil || operator.Type != Type_Operator || !slices.Contains(operators, operator.Value.(Operator)) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
*p = pCopy
|
||||
sep := operator.Value.(Operator)
|
||||
return &sep, nil
|
||||
}
|
||||
|
||||
func (p *Parser) expectIdentifier() (string, error) {
|
||||
identifier := p.nextToken()
|
||||
if identifier == nil || identifier.Type != Type_Identifier {
|
||||
@ -256,6 +265,13 @@ func (p *Parser) tryType() (*Type, error) {
|
||||
|
||||
if tok.Type == Type_Identifier {
|
||||
// TODO: array type
|
||||
|
||||
index := slices.Index(PRIMITIVE_TYPE_NAMES, tok.Value.(string))
|
||||
if index != -1 {
|
||||
*p = pCopy
|
||||
return &Type{Type: Type_Primitive, Value: PrimitiveType(index)}, nil
|
||||
}
|
||||
|
||||
*p = pCopy
|
||||
return &Type{Type: Type_Named, Value: tok.Value}, nil
|
||||
}
|
||||
@ -403,11 +419,81 @@ func (p *Parser) tryUnaryExpression() (*Expression, error) {
|
||||
}
|
||||
|
||||
func (p *Parser) tryMultiplicativeExpression() (*Expression, error) {
|
||||
return p.tryUnaryExpression()
|
||||
left, err := p.tryUnaryExpression()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
op, err := p.tryOperator(Operator_Multiply, Operator_Divide, Operator_Modulo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if op == nil {
|
||||
return left, nil
|
||||
}
|
||||
|
||||
right, err := p.tryUnaryExpression()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if right == nil {
|
||||
return nil, p.error("expected expression")
|
||||
}
|
||||
|
||||
var operation ArithmeticOperation
|
||||
switch *op {
|
||||
case Operator_Multiply:
|
||||
operation = Arithmetic_Mul
|
||||
case Operator_Divide:
|
||||
operation = Arithmetic_Div
|
||||
case Operator_Modulo:
|
||||
fallthrough
|
||||
default:
|
||||
operation = Arithmetic_Mod
|
||||
}
|
||||
if *op == Operator_Plus {
|
||||
operation = Arithmetic_Add
|
||||
} else {
|
||||
operation = Arithmetic_Sub
|
||||
}
|
||||
|
||||
return &Expression{Type: Expression_Arithmetic, Value: ArithmeticExpression{Operation: operation, Left: *left, Right: *right}}, nil
|
||||
}
|
||||
|
||||
func (p *Parser) tryAdditiveExpression() (*Expression, error) {
|
||||
return p.tryMultiplicativeExpression()
|
||||
left, err := p.tryMultiplicativeExpression()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
op, err := p.tryOperator(Operator_Plus, Operator_Minus)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if op == nil {
|
||||
return left, nil
|
||||
}
|
||||
|
||||
right, err := p.tryMultiplicativeExpression()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if right == nil {
|
||||
return nil, p.error("expected expression")
|
||||
}
|
||||
|
||||
var operation ArithmeticOperation
|
||||
if *op == Operator_Plus {
|
||||
operation = Arithmetic_Add
|
||||
} else {
|
||||
operation = Arithmetic_Sub
|
||||
}
|
||||
|
||||
return &Expression{Type: Expression_Arithmetic, Value: ArithmeticExpression{Operation: operation, Left: *left, Right: *right}}, nil
|
||||
}
|
||||
|
||||
func (p *Parser) tryArithmeticExpression() (*Expression, error) {
|
||||
|
17
types.go
17
types.go
@ -1,6 +1,8 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"slices"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
@ -16,7 +18,7 @@ type Lang_U64 uint64
|
||||
|
||||
type Lang_Bool bool
|
||||
|
||||
var NumberTypeNames = [...]string{"i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "f32", "f64", "bool"}
|
||||
var PRIMITIVE_TYPE_NAMES = []string{"i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "f32", "f64", "bool"}
|
||||
|
||||
type PrimitiveType uint32
|
||||
|
||||
@ -34,6 +36,10 @@ const (
|
||||
Primitive_Bool
|
||||
)
|
||||
|
||||
const STRING_TYPE_NAME = "string"
|
||||
|
||||
var STRING_TYPE = Type{Type: Type_Named, Value: STRING_TYPE_NAME}
|
||||
|
||||
const InvalidValue = 0xEEEEEE // Magic value
|
||||
|
||||
type CompilerError struct {
|
||||
@ -86,3 +92,12 @@ func getBits(primitiveType PrimitiveType) int {
|
||||
panic("Passed an invalid type (" + strconv.FormatUint(uint64(primitiveType), 10) + ") to getBits()")
|
||||
}
|
||||
}
|
||||
|
||||
func getPrimitiveTypeByName(name string) (PrimitiveType, error) {
|
||||
idx := slices.Index(PRIMITIVE_TYPE_NAMES, name)
|
||||
if idx == -1 {
|
||||
return InvalidValue, errors.New("not a primitive type name")
|
||||
}
|
||||
|
||||
return PrimitiveType(idx), nil
|
||||
}
|
||||
|
105
validator.go
105
validator.go
@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
)
|
||||
|
||||
func createError(message string) error {
|
||||
@ -14,6 +15,59 @@ func validateImport(imp *Import) []error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func isTypeExpandableTo(from PrimitiveType, to PrimitiveType) bool {
|
||||
if from == to {
|
||||
return true
|
||||
}
|
||||
|
||||
switch from {
|
||||
case Primitive_I8:
|
||||
case Primitive_U8:
|
||||
if to == Primitive_I16 || to == Primitive_U16 {
|
||||
return true
|
||||
}
|
||||
|
||||
fallthrough
|
||||
case Primitive_I16:
|
||||
case Primitive_U16:
|
||||
if to == Primitive_I32 || to == Primitive_U32 {
|
||||
return true
|
||||
}
|
||||
|
||||
fallthrough
|
||||
case Primitive_I32:
|
||||
case Primitive_U32:
|
||||
if to == Primitive_I64 || to == Primitive_U64 {
|
||||
return true
|
||||
}
|
||||
|
||||
case Primitive_F32:
|
||||
if to == Primitive_F64 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func getArithmeticResultType(left PrimitiveType, right PrimitiveType, operation ArithmeticOperation) (PrimitiveType, error) {
|
||||
if left == Primitive_Bool || right == Primitive_Bool {
|
||||
return InvalidValue, createError("bool type cannot be used in arithmetic expressions")
|
||||
}
|
||||
|
||||
if isTypeExpandableTo(left, right) {
|
||||
return right, nil
|
||||
}
|
||||
|
||||
if isTypeExpandableTo(right, left) {
|
||||
return left, nil
|
||||
}
|
||||
|
||||
// TODO: boolean expressions etc.
|
||||
|
||||
return InvalidValue, createError("cannot use these types in an arithmetic expression without an explicit cast") // TODO: include type names in error
|
||||
}
|
||||
|
||||
func validateExpression(expr *Expression, block *Block) []error {
|
||||
var errors []error
|
||||
|
||||
@ -66,12 +120,28 @@ func validateExpression(expr *Expression, block *Block) []error {
|
||||
}
|
||||
|
||||
// TODO: validate types compatible and determine result type
|
||||
if arithmethic.Left.ValueType.Type != Type_Primitive || arithmethic.Right.ValueType.Type != Type_Primitive {
|
||||
errors = append(errors, createError("both sides of an arithmetic expression must be a primitive type"))
|
||||
return errors
|
||||
}
|
||||
|
||||
leftType := arithmethic.Left.ValueType.Value.(PrimitiveType)
|
||||
rightType := arithmethic.Left.ValueType.Value.(PrimitiveType)
|
||||
result, err := getArithmeticResultType(leftType, rightType, arithmethic.Operation)
|
||||
if err != nil {
|
||||
errors = append(errors, err)
|
||||
return errors
|
||||
}
|
||||
|
||||
expr.ValueType = Type{Type: Type_Primitive, Value: result}
|
||||
case Expression_Tuple:
|
||||
tuple := expr.Value.(TupleExpression)
|
||||
|
||||
var types []Type
|
||||
for _, member := range tuple.Members {
|
||||
memberErrors := validateExpression(&member, block)
|
||||
for i := range tuple.Members {
|
||||
member := &tuple.Members[i]
|
||||
|
||||
memberErrors := validateExpression(member, block)
|
||||
if len(memberErrors) != 0 {
|
||||
errors = append(errors, memberErrors...)
|
||||
continue
|
||||
@ -111,6 +181,10 @@ func validateStatement(stmt *Statement, block *Block) []error {
|
||||
errors = append(errors, validateExpression(dlv.Initializer, block)...)
|
||||
}
|
||||
|
||||
if _, ok := block.Locals[dlv.Variable]; ok {
|
||||
errors = append(errors, createError("redeclaration of variable '"+dlv.Variable+"'"))
|
||||
}
|
||||
|
||||
block.Locals[dlv.Variable] = Local{Name: dlv.Variable, Type: dlv.VariableType}
|
||||
}
|
||||
|
||||
@ -120,10 +194,13 @@ func validateStatement(stmt *Statement, block *Block) []error {
|
||||
func validateBlock(block *Block) []error {
|
||||
var errors []error
|
||||
|
||||
block.Locals = make(map[string]Local)
|
||||
if block.Locals == nil {
|
||||
block.Locals = make(map[string]Local)
|
||||
}
|
||||
|
||||
for _, stmt := range block.Statements {
|
||||
errors = append(errors, validateStatement(&stmt, block)...)
|
||||
for i := range block.Statements {
|
||||
stmt := &block.Statements[i]
|
||||
errors = append(errors, validateStatement(stmt, block)...)
|
||||
}
|
||||
|
||||
return errors
|
||||
@ -132,7 +209,15 @@ func validateBlock(block *Block) []error {
|
||||
func validateFunction(function *ParsedFunction) []error {
|
||||
var errors []error
|
||||
|
||||
errors = append(errors, validateBlock(&function.Body)...)
|
||||
body := &function.Body
|
||||
body.Locals = make(map[string]Local)
|
||||
for _, param := range function.Parameters {
|
||||
body.Locals[param.Name] = Local(param)
|
||||
}
|
||||
|
||||
errors = append(errors, validateBlock(body)...)
|
||||
|
||||
log.Printf("%+#v", body)
|
||||
|
||||
return errors
|
||||
}
|
||||
@ -140,12 +225,12 @@ func validateFunction(function *ParsedFunction) []error {
|
||||
func validator(file *ParsedFile) []error {
|
||||
var errors []error
|
||||
|
||||
for _, imp := range file.Imports {
|
||||
errors = append(errors, validateImport(&imp)...)
|
||||
for i := range file.Imports {
|
||||
errors = append(errors, validateImport(&file.Imports[i])...)
|
||||
}
|
||||
|
||||
for _, function := range file.Functions {
|
||||
errors = append(errors, validateFunction(&function)...)
|
||||
for i := range file.Functions {
|
||||
errors = append(errors, validateFunction(&file.Functions[i])...)
|
||||
}
|
||||
|
||||
return errors
|
||||
|
Loading…
Reference in New Issue
Block a user