More validation

This commit is contained in:
MrLetsplay 2024-03-17 19:55:28 +01:00
parent a8d5c7d479
commit ad32195fa2
Signed by: mr
SSH Key Fingerprint: SHA256:92jBH80vpXyaZHjaIl47pjRq+Yt7XGTArqQg1V7hSqg
6 changed files with 210 additions and 18 deletions

4
example/add.lang Normal file
View File

@ -0,0 +1,4 @@
u8 add(u8 a, u8 b) {
u8 c = b;
return a + b;
}

View File

@ -218,7 +218,7 @@ func (l *Lexer) parseToken(token string) (*LexToken, error) {
var numberType PrimitiveType = InvalidValue
var rawNumber string = token
for i, name := range NumberTypeNames {
for i, name := range PRIMITIVE_TYPE_NAMES {
if strings.HasSuffix(token, name) {
numberType = PrimitiveType(i)
rawNumber = token[:len(token)-len(name)]

View File

@ -49,4 +49,6 @@ func main() {
log.Fatalln(err)
}
}
log.Printf("Validated:\n%+#v\n\n", parsed)
}

View File

@ -20,10 +20,6 @@ type Type struct {
Value any
}
const STRING_TYPE_NAME = "string"
var STRING_TYPE = Type{Type: Type_Named, Value: STRING_TYPE_NAME}
type NamedType struct {
TypeName string
}
@ -211,6 +207,19 @@ func (p *Parser) expectSeparator(separators ...Separator) (Separator, error) {
return *sep, nil
}
func (p *Parser) tryOperator(operators ...Operator) (*Operator, error) {
pCopy := p.copy()
operator := pCopy.nextToken()
if operator == nil || operator.Type != Type_Operator || !slices.Contains(operators, operator.Value.(Operator)) {
return nil, nil
}
*p = pCopy
sep := operator.Value.(Operator)
return &sep, nil
}
func (p *Parser) expectIdentifier() (string, error) {
identifier := p.nextToken()
if identifier == nil || identifier.Type != Type_Identifier {
@ -256,6 +265,13 @@ func (p *Parser) tryType() (*Type, error) {
if tok.Type == Type_Identifier {
// TODO: array type
index := slices.Index(PRIMITIVE_TYPE_NAMES, tok.Value.(string))
if index != -1 {
*p = pCopy
return &Type{Type: Type_Primitive, Value: PrimitiveType(index)}, nil
}
*p = pCopy
return &Type{Type: Type_Named, Value: tok.Value}, nil
}
@ -403,11 +419,81 @@ func (p *Parser) tryUnaryExpression() (*Expression, error) {
}
func (p *Parser) tryMultiplicativeExpression() (*Expression, error) {
return p.tryUnaryExpression()
left, err := p.tryUnaryExpression()
if err != nil {
return nil, err
}
op, err := p.tryOperator(Operator_Multiply, Operator_Divide, Operator_Modulo)
if err != nil {
return nil, err
}
if op == nil {
return left, nil
}
right, err := p.tryUnaryExpression()
if err != nil {
return nil, err
}
if right == nil {
return nil, p.error("expected expression")
}
var operation ArithmeticOperation
switch *op {
case Operator_Multiply:
operation = Arithmetic_Mul
case Operator_Divide:
operation = Arithmetic_Div
case Operator_Modulo:
fallthrough
default:
operation = Arithmetic_Mod
}
if *op == Operator_Plus {
operation = Arithmetic_Add
} else {
operation = Arithmetic_Sub
}
return &Expression{Type: Expression_Arithmetic, Value: ArithmeticExpression{Operation: operation, Left: *left, Right: *right}}, nil
}
func (p *Parser) tryAdditiveExpression() (*Expression, error) {
return p.tryMultiplicativeExpression()
left, err := p.tryMultiplicativeExpression()
if err != nil {
return nil, err
}
op, err := p.tryOperator(Operator_Plus, Operator_Minus)
if err != nil {
return nil, err
}
if op == nil {
return left, nil
}
right, err := p.tryMultiplicativeExpression()
if err != nil {
return nil, err
}
if right == nil {
return nil, p.error("expected expression")
}
var operation ArithmeticOperation
if *op == Operator_Plus {
operation = Arithmetic_Add
} else {
operation = Arithmetic_Sub
}
return &Expression{Type: Expression_Arithmetic, Value: ArithmeticExpression{Operation: operation, Left: *left, Right: *right}}, nil
}
func (p *Parser) tryArithmeticExpression() (*Expression, error) {

View File

@ -1,6 +1,8 @@
package main
import (
"errors"
"slices"
"strconv"
)
@ -16,7 +18,7 @@ type Lang_U64 uint64
type Lang_Bool bool
var NumberTypeNames = [...]string{"i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "f32", "f64", "bool"}
var PRIMITIVE_TYPE_NAMES = []string{"i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "f32", "f64", "bool"}
type PrimitiveType uint32
@ -34,6 +36,10 @@ const (
Primitive_Bool
)
const STRING_TYPE_NAME = "string"
var STRING_TYPE = Type{Type: Type_Named, Value: STRING_TYPE_NAME}
const InvalidValue = 0xEEEEEE // Magic value
type CompilerError struct {
@ -86,3 +92,12 @@ func getBits(primitiveType PrimitiveType) int {
panic("Passed an invalid type (" + strconv.FormatUint(uint64(primitiveType), 10) + ") to getBits()")
}
}
func getPrimitiveTypeByName(name string) (PrimitiveType, error) {
idx := slices.Index(PRIMITIVE_TYPE_NAMES, name)
if idx == -1 {
return InvalidValue, errors.New("not a primitive type name")
}
return PrimitiveType(idx), nil
}

View File

@ -2,6 +2,7 @@ package main
import (
"errors"
"log"
)
func createError(message string) error {
@ -14,6 +15,59 @@ func validateImport(imp *Import) []error {
return nil
}
func isTypeExpandableTo(from PrimitiveType, to PrimitiveType) bool {
if from == to {
return true
}
switch from {
case Primitive_I8:
case Primitive_U8:
if to == Primitive_I16 || to == Primitive_U16 {
return true
}
fallthrough
case Primitive_I16:
case Primitive_U16:
if to == Primitive_I32 || to == Primitive_U32 {
return true
}
fallthrough
case Primitive_I32:
case Primitive_U32:
if to == Primitive_I64 || to == Primitive_U64 {
return true
}
case Primitive_F32:
if to == Primitive_F64 {
return true
}
}
return false
}
func getArithmeticResultType(left PrimitiveType, right PrimitiveType, operation ArithmeticOperation) (PrimitiveType, error) {
if left == Primitive_Bool || right == Primitive_Bool {
return InvalidValue, createError("bool type cannot be used in arithmetic expressions")
}
if isTypeExpandableTo(left, right) {
return right, nil
}
if isTypeExpandableTo(right, left) {
return left, nil
}
// TODO: boolean expressions etc.
return InvalidValue, createError("cannot use these types in an arithmetic expression without an explicit cast") // TODO: include type names in error
}
func validateExpression(expr *Expression, block *Block) []error {
var errors []error
@ -66,12 +120,28 @@ func validateExpression(expr *Expression, block *Block) []error {
}
// TODO: validate types compatible and determine result type
if arithmethic.Left.ValueType.Type != Type_Primitive || arithmethic.Right.ValueType.Type != Type_Primitive {
errors = append(errors, createError("both sides of an arithmetic expression must be a primitive type"))
return errors
}
leftType := arithmethic.Left.ValueType.Value.(PrimitiveType)
rightType := arithmethic.Left.ValueType.Value.(PrimitiveType)
result, err := getArithmeticResultType(leftType, rightType, arithmethic.Operation)
if err != nil {
errors = append(errors, err)
return errors
}
expr.ValueType = Type{Type: Type_Primitive, Value: result}
case Expression_Tuple:
tuple := expr.Value.(TupleExpression)
var types []Type
for _, member := range tuple.Members {
memberErrors := validateExpression(&member, block)
for i := range tuple.Members {
member := &tuple.Members[i]
memberErrors := validateExpression(member, block)
if len(memberErrors) != 0 {
errors = append(errors, memberErrors...)
continue
@ -111,6 +181,10 @@ func validateStatement(stmt *Statement, block *Block) []error {
errors = append(errors, validateExpression(dlv.Initializer, block)...)
}
if _, ok := block.Locals[dlv.Variable]; ok {
errors = append(errors, createError("redeclaration of variable '"+dlv.Variable+"'"))
}
block.Locals[dlv.Variable] = Local{Name: dlv.Variable, Type: dlv.VariableType}
}
@ -120,10 +194,13 @@ func validateStatement(stmt *Statement, block *Block) []error {
func validateBlock(block *Block) []error {
var errors []error
block.Locals = make(map[string]Local)
if block.Locals == nil {
block.Locals = make(map[string]Local)
}
for _, stmt := range block.Statements {
errors = append(errors, validateStatement(&stmt, block)...)
for i := range block.Statements {
stmt := &block.Statements[i]
errors = append(errors, validateStatement(stmt, block)...)
}
return errors
@ -132,7 +209,15 @@ func validateBlock(block *Block) []error {
func validateFunction(function *ParsedFunction) []error {
var errors []error
errors = append(errors, validateBlock(&function.Body)...)
body := &function.Body
body.Locals = make(map[string]Local)
for _, param := range function.Parameters {
body.Locals[param.Name] = Local(param)
}
errors = append(errors, validateBlock(body)...)
log.Printf("%+#v", body)
return errors
}
@ -140,12 +225,12 @@ func validateFunction(function *ParsedFunction) []error {
func validator(file *ParsedFile) []error {
var errors []error
for _, imp := range file.Imports {
errors = append(errors, validateImport(&imp)...)
for i := range file.Imports {
errors = append(errors, validateImport(&file.Imports[i])...)
}
for _, function := range file.Functions {
errors = append(errors, validateFunction(&function)...)
for i := range file.Functions {
errors = append(errors, validateFunction(&file.Functions[i])...)
}
return errors