Improve compiler & validator, Update lexer

This commit is contained in:
MrLetsplay 2024-03-20 19:26:48 +01:00
parent 8519f38bab
commit 63ccacba2d
Signed by: mr
SSH Key Fingerprint: SHA256:92jBH80vpXyaZHjaIl47pjRq+Yt7XGTArqQg1V7hSqg
6 changed files with 282 additions and 98 deletions

View File

@ -2,6 +2,7 @@ package main
import ( import (
"errors" "errors"
"log"
"strconv" "strconv"
) )
@ -136,9 +137,11 @@ func castPrimitiveWAT(from PrimitiveType, to PrimitiveType) (string, error) {
} }
func compileExpressionWAT(expr Expression, block Block) (string, error) { func compileExpressionWAT(expr Expression, block Block) (string, error) {
var err error
switch expr.Type { switch expr.Type {
case Expression_Assignment: case Expression_Assignment:
// TODO
case Expression_Literal: case Expression_Literal:
lit := expr.Value.(LiteralExpression) lit := expr.Value.(LiteralExpression)
switch lit.Literal.Type { switch lit.Literal.Type {
@ -165,6 +168,8 @@ func compileExpressionWAT(expr Expression, block Block) (string, error) {
case Expression_Arithmetic: case Expression_Arithmetic:
arith := expr.Value.(ArithmeticExpression) arith := expr.Value.(ArithmeticExpression)
log.Printf("%+#v", arith)
// TODO: currently fine, only allowed for primitive types, but should be expanded to allow e.g. strings // TODO: currently fine, only allowed for primitive types, but should be expanded to allow e.g. strings
exprType := expr.ValueType.Value.(PrimitiveType) exprType := expr.ValueType.Value.(PrimitiveType)
@ -212,9 +217,51 @@ func compileExpressionWAT(expr Expression, block Block) (string, error) {
return watLeft + castLeft + watRight + castRight + op + getTypeCast(exprType), nil return watLeft + castLeft + watRight + castRight + op + getTypeCast(exprType), nil
case Expression_Tuple: case Expression_Tuple:
tuple := expr.Value.(TupleExpression)
wat := ""
for _, member := range tuple.Members {
memberWAT, err := compileExpressionWAT(member, block)
if err != nil {
return "", err
}
wat += memberWAT
}
return wat, nil
case Expression_FunctionCall:
fc := expr.Value.(FunctionCallExpression)
wat := ""
if fc.Parameters != nil {
wat, err = compileExpressionWAT(*fc.Parameters, block)
if err != nil {
return "", err
}
}
return wat + "call $" + fc.Function + "\n", nil
case Expression_Negate:
neg := expr.Value.(NegateExpression)
exprType := expr.ValueType.Value.(PrimitiveType)
wat, err := compileExpressionWAT(neg.Value, block)
if err != nil {
return "", err
}
watType := getPrimitiveWATType(exprType)
if isSignedInt(exprType) || isUnsignedInt(exprType) {
return watType + ".const 0\n" + wat + watType + ".sub\n", nil
}
if isFloatingPoint(exprType) {
return watType + ".neg\n", nil
}
} }
return "", nil panic("expr not implemented")
} }
func compileStatementWAT(stmt Statement, block Block) (string, error) { func compileStatementWAT(stmt Statement, block Block) (string, error) {
@ -286,7 +333,14 @@ func compileFunctionWAT(function ParsedFunction) (string, error) {
} }
// TODO: tuples // TODO: tuples
funcWAT += "\t(result " + getWATType(function.ReturnType) + ")\n" returnTypes := []Type{function.ReturnType}
if function.ReturnType.Type == Type_Tuple {
returnTypes = function.ReturnType.Value.(TupleType).Types
}
for _, t := range returnTypes {
funcWAT += "\t(result " + getWATType(t) + ")\n"
}
for _, local := range function.Locals { for _, local := range function.Locals {
if local.IsParameter { if local.IsParameter {

View File

@ -1,3 +1,7 @@
u64 add(u8 a, u64 b) { u64 add(u8 a, u8 b) {
return a * a + b * b; return add(a - 1u8, a);
}
(u8, u8) doNothing(u8 a, u8 b) {
return a, b;
} }

View File

@ -212,8 +212,7 @@ func (l *Lexer) parseToken(token string) (*LexToken, error) {
} }
runes := []rune(token) runes := []rune(token)
startsWithMinus := runes[0] == '-' if unicode.IsDigit([]rune(token)[0]) {
if startsWithMinus || unicode.IsDigit([]rune(token)[0]) {
// TODO: hexadecimal/binary/octal constants // TODO: hexadecimal/binary/octal constants
var numberType PrimitiveType = InvalidValue var numberType PrimitiveType = InvalidValue
@ -230,10 +229,8 @@ func (l *Lexer) parseToken(token string) (*LexToken, error) {
if numberType == InvalidValue { if numberType == InvalidValue {
if containsDot { if containsDot {
numberType = Primitive_F64 numberType = Primitive_F64
} else if startsWithMinus {
numberType = Primitive_I64
} else { } else {
numberType = Primitive_U64 numberType = Primitive_I64
} }
} }

View File

@ -39,7 +39,8 @@ func main() {
log.Printf("Parsed:\n%+#v\n\n", parsed) log.Printf("Parsed:\n%+#v\n\n", parsed)
errors := validator(parsed) validator := Validator{file: parsed}
errors := validator.validate()
if len(errors) != 0 { if len(errors) != 0 {
for _, err = range errors { for _, err = range errors {
if c, ok := err.(CompilerError); ok { if c, ok := err.(CompilerError); ok {

164
parser.go
View File

@ -72,6 +72,8 @@ const (
Expression_VariableReference Expression_VariableReference
Expression_Arithmetic Expression_Arithmetic
Expression_Tuple Expression_Tuple
Expression_FunctionCall
Expression_Negate
) )
type Expression struct { type Expression struct {
@ -113,6 +115,15 @@ type TupleExpression struct {
Members []Expression Members []Expression
} }
type FunctionCallExpression struct {
Function string
Parameters *Expression
}
type NegateExpression struct {
Value Expression
}
type Local struct { type Local struct {
Name string Name string
Type Type Type Type
@ -413,8 +424,51 @@ func (p *Parser) tryUnaryExpression() (*Expression, error) {
} }
} }
if token.Type == Type_Operator {
op := token.Value.(Operator)
if op == Operator_Minus || op == Operator_Plus {
pCopy.nextToken()
expr, err := pCopy.tryUnaryExpression()
if err != nil {
return nil, err
}
if expr == nil {
return nil, nil
}
if op == Operator_Minus {
expr = &Expression{Type: Expression_Negate, Value: NegateExpression{Value: *expr}}
}
return expr, nil
}
}
if token.Type == Type_Identifier { if token.Type == Type_Identifier {
pCopy.nextToken() pCopy.nextToken()
next, err := pCopy.trySeparator(Separator_OpenParen)
if err != nil {
return nil, err
}
if next != nil {
// Function call
params, err := pCopy.tryTupleExpression()
if err != nil {
return nil, err
}
_, err = pCopy.expectSeparator(Separator_CloseParen)
if err != nil {
return nil, err
}
*p = pCopy
return &Expression{Type: Expression_FunctionCall, Value: FunctionCallExpression{Function: token.Value.(string), Parameters: params}}, nil
}
*p = pCopy *p = pCopy
return &Expression{Type: Expression_VariableReference, Value: VariableReferenceExpression{Variable: token.Value.(string)}}, nil return &Expression{Type: Expression_VariableReference, Value: VariableReferenceExpression{Variable: token.Value.(string)}}, nil
} }
@ -428,41 +482,47 @@ func (p *Parser) tryMultiplicativeExpression() (*Expression, error) {
return nil, err return nil, err
} }
op, err := p.tryOperator(Operator_Multiply, Operator_Divide, Operator_Modulo) if left == nil {
if err != nil { return nil, nil
return nil, err
} }
if op == nil { for {
return left, nil op, err := p.tryOperator(Operator_Multiply, Operator_Divide, Operator_Modulo)
} if err != nil {
return nil, err
}
right, err := p.tryUnaryExpression() if op == nil {
if err != nil { return left, nil
return nil, err }
}
if right == nil { right, err := p.tryUnaryExpression()
return nil, p.error("expected expression") if err != nil {
} return nil, err
}
var operation ArithmeticOperation if right == nil {
switch *op { return nil, p.error("expected expression")
case Operator_Multiply: }
operation = Arithmetic_Mul
case Operator_Divide:
operation = Arithmetic_Div
case Operator_Plus:
operation = Arithmetic_Add
case Operator_Minus:
operation = Arithmetic_Sub
case Operator_Modulo:
fallthrough
default:
operation = Arithmetic_Mod
}
return &Expression{Type: Expression_Arithmetic, Value: ArithmeticExpression{Operation: operation, Left: *left, Right: *right}}, nil var operation ArithmeticOperation
switch *op {
case Operator_Multiply:
operation = Arithmetic_Mul
case Operator_Divide:
operation = Arithmetic_Div
case Operator_Plus:
operation = Arithmetic_Add
case Operator_Minus:
operation = Arithmetic_Sub
case Operator_Modulo:
fallthrough
default:
operation = Arithmetic_Mod
}
left = &Expression{Type: Expression_Arithmetic, Value: ArithmeticExpression{Operation: operation, Left: *left, Right: *right}}
}
} }
func (p *Parser) tryAdditiveExpression() (*Expression, error) { func (p *Parser) tryAdditiveExpression() (*Expression, error) {
@ -471,32 +531,38 @@ func (p *Parser) tryAdditiveExpression() (*Expression, error) {
return nil, err return nil, err
} }
op, err := p.tryOperator(Operator_Plus, Operator_Minus) if left == nil {
if err != nil { return nil, nil
return nil, err
} }
if op == nil { for {
return left, nil op, err := p.tryOperator(Operator_Plus, Operator_Minus)
} if err != nil {
return nil, err
}
right, err := p.tryMultiplicativeExpression() if op == nil {
if err != nil { return left, nil
return nil, err }
}
if right == nil { right, err := p.tryMultiplicativeExpression()
return nil, p.error("expected expression") if err != nil {
} return nil, err
}
var operation ArithmeticOperation if right == nil {
if *op == Operator_Plus { return nil, p.error("expected expression")
operation = Arithmetic_Add }
} else {
operation = Arithmetic_Sub
}
return &Expression{Type: Expression_Arithmetic, Value: ArithmeticExpression{Operation: operation, Left: *left, Right: *right}}, nil var operation ArithmeticOperation
if *op == Operator_Plus {
operation = Arithmetic_Add
} else {
operation = Arithmetic_Sub
}
left = &Expression{Type: Expression_Arithmetic, Value: ArithmeticExpression{Operation: operation, Left: *left, Right: *right}}
}
} }
func (p *Parser) tryArithmeticExpression() (*Expression, error) { func (p *Parser) tryArithmeticExpression() (*Expression, error) {

View File

@ -2,19 +2,22 @@ package main
import ( import (
"errors" "errors"
"strconv"
) )
func createError(message string) error { type Validator struct {
// TODO: pass token and get actual token position file *ParsedFile
return errors.New(message)
} }
func validateImport(imp *Import) []error { func isTypeExpandableTo(from Type, to Type) bool {
// TODO if from.Type == Type_Primitive && to.Type == Type_Primitive {
return nil return isPrimitiveTypeExpandableTo(from.Value.(PrimitiveType), to.Value.(PrimitiveType))
}
panic("not implemented")
} }
func isTypeExpandableTo(from PrimitiveType, to PrimitiveType) bool { func isPrimitiveTypeExpandableTo(from PrimitiveType, to PrimitiveType) bool {
if from == to { if from == to {
return true return true
} }
@ -46,25 +49,35 @@ func isTypeExpandableTo(from PrimitiveType, to PrimitiveType) bool {
return false return false
} }
func getArithmeticResultType(left PrimitiveType, right PrimitiveType, operation ArithmeticOperation) (PrimitiveType, error) { func (v *Validator) createError(message string) error {
// TODO: pass token and get actual token position
return errors.New(message)
}
func (v *Validator) validateImport(imp *Import) []error {
// TODO
return nil
}
func (v *Validator) getArithmeticResultType(left PrimitiveType, right PrimitiveType, operation ArithmeticOperation) (PrimitiveType, error) {
if left == Primitive_Bool || right == Primitive_Bool { if left == Primitive_Bool || right == Primitive_Bool {
return InvalidValue, createError("bool type cannot be used in arithmetic expressions") return InvalidValue, v.createError("bool type cannot be used in arithmetic expressions")
} }
if isTypeExpandableTo(left, right) { if isPrimitiveTypeExpandableTo(left, right) {
return right, nil return right, nil
} }
if isTypeExpandableTo(right, left) { if isPrimitiveTypeExpandableTo(right, left) {
return left, nil return left, nil
} }
// TODO: boolean expressions etc. // TODO: boolean expressions etc.
return InvalidValue, createError("cannot use these types in an arithmetic expression without an explicit cast") // TODO: include type names in error return InvalidValue, v.createError("cannot use these types in an arithmetic expression without an explicit cast") // TODO: include type names in error
} }
func validateExpression(expr *Expression, block *Block) []error { func (v *Validator) validateExpression(expr *Expression, block *Block) []error {
var errors []error var errors []error
switch expr.Type { switch expr.Type {
@ -73,11 +86,11 @@ func validateExpression(expr *Expression, block *Block) []error {
var local Local var local Local
var ok bool var ok bool
if local, ok = block.Locals[assignment.Variable]; !ok { if local, ok = block.Locals[assignment.Variable]; !ok {
errors = append(errors, createError("Assignment to undeclared variable "+assignment.Variable)) errors = append(errors, v.createError("Assignment to undeclared variable "+assignment.Variable))
return errors return errors
} }
valueErrors := validateExpression(&assignment.Value, block) valueErrors := v.validateExpression(&assignment.Value, block)
if len(valueErrors) != 0 { if len(valueErrors) != 0 {
errors = append(errors, valueErrors...) errors = append(errors, valueErrors...)
return errors return errors
@ -100,31 +113,29 @@ func validateExpression(expr *Expression, block *Block) []error {
var local Local var local Local
var ok bool var ok bool
if local, ok = block.Locals[reference.Variable]; !ok { if local, ok = block.Locals[reference.Variable]; !ok {
errors = append(errors, createError("Reference to undeclared variable "+reference.Variable)) errors = append(errors, v.createError("Reference to undeclared variable "+reference.Variable))
return errors return errors
} }
expr.ValueType = local.Type expr.ValueType = local.Type
expr.Value = reference
case Expression_Arithmetic: case Expression_Arithmetic:
arithmethic := expr.Value.(ArithmeticExpression) arithmethic := expr.Value.(ArithmeticExpression)
errors = append(errors, validateExpression(&arithmethic.Left, block)...) errors = append(errors, v.validateExpression(&arithmethic.Left, block)...)
errors = append(errors, validateExpression(&arithmethic.Right, block)...) errors = append(errors, v.validateExpression(&arithmethic.Right, block)...)
if len(errors) != 0 { if len(errors) != 0 {
return errors return errors
} }
// TODO: validate types compatible and determine result type
if arithmethic.Left.ValueType.Type != Type_Primitive || arithmethic.Right.ValueType.Type != Type_Primitive { if arithmethic.Left.ValueType.Type != Type_Primitive || arithmethic.Right.ValueType.Type != Type_Primitive {
errors = append(errors, createError("both sides of an arithmetic expression must be a primitive type")) errors = append(errors, v.createError("both sides of an arithmetic expression must be a primitive type"))
return errors return errors
} }
leftType := arithmethic.Left.ValueType.Value.(PrimitiveType) leftType := arithmethic.Left.ValueType.Value.(PrimitiveType)
rightType := arithmethic.Right.ValueType.Value.(PrimitiveType) rightType := arithmethic.Right.ValueType.Value.(PrimitiveType)
result, err := getArithmeticResultType(leftType, rightType, arithmethic.Operation) result, err := v.getArithmeticResultType(leftType, rightType, arithmethic.Operation)
if err != nil { if err != nil {
errors = append(errors, err) errors = append(errors, err)
return errors return errors
@ -139,7 +150,7 @@ func validateExpression(expr *Expression, block *Block) []error {
for i := range tuple.Members { for i := range tuple.Members {
member := &tuple.Members[i] member := &tuple.Members[i]
memberErrors := validateExpression(member, block) memberErrors := v.validateExpression(member, block)
if len(memberErrors) != 0 { if len(memberErrors) != 0 {
errors = append(errors, memberErrors...) errors = append(errors, memberErrors...)
continue continue
@ -154,12 +165,63 @@ func validateExpression(expr *Expression, block *Block) []error {
expr.ValueType = Type{Type: Type_Tuple, Value: TupleType{Types: types}} expr.ValueType = Type{Type: Type_Tuple, Value: TupleType{Types: types}}
expr.Value = tuple expr.Value = tuple
case Expression_FunctionCall:
fc := expr.Value.(FunctionCallExpression)
var calledFunc *ParsedFunction = nil
for _, f := range v.file.Functions {
if f.Name == fc.Function {
calledFunc = &f
break
}
}
if calledFunc == nil {
errors = append(errors, v.createError("call to undefined function '"+fc.Function+"'"))
return errors
}
if fc.Parameters != nil {
errors = append(errors, v.validateExpression(fc.Parameters, block)...)
params := []Expression{*fc.Parameters}
if fc.Parameters.Type == Expression_Tuple {
params = fc.Parameters.Value.(TupleExpression).Members
}
if len(params) != len(calledFunc.Parameters) {
errors = append(errors, v.createError("wrong number of arguments in function call: expected "+strconv.Itoa(len(calledFunc.Parameters))+", got "+strconv.Itoa(len(params))))
}
for i := 0; i < min(len(params), len(calledFunc.Parameters)); i++ {
typeGiven := params[i]
typeExpected := calledFunc.Parameters[i]
if !isTypeExpandableTo(typeGiven.ValueType, typeExpected.Type) {
errors = append(errors, v.createError("invalid type for parameter "+strconv.Itoa(i)))
}
}
}
// TODO: get function and validate using return type
expr.ValueType = calledFunc.ReturnType
expr.Value = fc
case Expression_Negate:
neg := expr.Value.(NegateExpression)
errors = append(errors, v.validateExpression(&neg.Value, block)...)
if neg.Value.ValueType.Type != Type_Primitive {
errors = append(errors, v.createError("cannot negate non-number types"))
}
expr.ValueType = neg.Value.ValueType
expr.Value = neg
} }
return errors return errors
} }
func validateStatement(stmt *Statement, block *Block, functionLocals *[]Local) []error { func (v *Validator) validateStatement(stmt *Statement, block *Block, functionLocals *[]Local) []error {
var errors []error var errors []error
// TODO: support references to variables in parent block // TODO: support references to variables in parent block
@ -167,25 +229,25 @@ func validateStatement(stmt *Statement, block *Block, functionLocals *[]Local) [
switch stmt.Type { switch stmt.Type {
case Statement_Expression: case Statement_Expression:
expression := stmt.Value.(ExpressionStatement) expression := stmt.Value.(ExpressionStatement)
errors = append(errors, validateExpression(&expression.Expression, block)...) errors = append(errors, v.validateExpression(&expression.Expression, block)...)
*stmt = Statement{Type: Statement_Expression, Value: expression} *stmt = Statement{Type: Statement_Expression, Value: expression}
case Statement_Block: case Statement_Block:
block := stmt.Value.(BlockStatement) block := stmt.Value.(BlockStatement)
errors = append(errors, validateBlock(&block.Block, functionLocals)...) errors = append(errors, v.validateBlock(&block.Block, functionLocals)...)
*stmt = Statement{Type: Statement_Block, Value: block} *stmt = Statement{Type: Statement_Block, Value: block}
case Statement_Return: case Statement_Return:
ret := stmt.Value.(ReturnStatement) ret := stmt.Value.(ReturnStatement)
if ret.Value != nil { if ret.Value != nil {
errors = append(errors, validateExpression(ret.Value, block)...) errors = append(errors, v.validateExpression(ret.Value, block)...)
} }
case Statement_DeclareLocalVariable: case Statement_DeclareLocalVariable:
dlv := stmt.Value.(DeclareLocalVariableStatement) dlv := stmt.Value.(DeclareLocalVariableStatement)
if dlv.Initializer != nil { if dlv.Initializer != nil {
errors = append(errors, validateExpression(dlv.Initializer, block)...) errors = append(errors, v.validateExpression(dlv.Initializer, block)...)
} }
if _, ok := block.Locals[dlv.Variable]; ok { if _, ok := block.Locals[dlv.Variable]; ok {
errors = append(errors, createError("redeclaration of variable '"+dlv.Variable+"'")) errors = append(errors, v.createError("redeclaration of variable '"+dlv.Variable+"'"))
} }
local := Local{Name: dlv.Variable, Type: dlv.VariableType, IsParameter: false, Index: len(*functionLocals)} local := Local{Name: dlv.Variable, Type: dlv.VariableType, IsParameter: false, Index: len(*functionLocals)}
@ -198,7 +260,7 @@ func validateStatement(stmt *Statement, block *Block, functionLocals *[]Local) [
return errors return errors
} }
func validateBlock(block *Block, functionLocals *[]Local) []error { func (v *Validator) validateBlock(block *Block, functionLocals *[]Local) []error {
var errors []error var errors []error
if block.Locals == nil { if block.Locals == nil {
@ -207,13 +269,13 @@ func validateBlock(block *Block, functionLocals *[]Local) []error {
for i := range block.Statements { for i := range block.Statements {
stmt := &block.Statements[i] stmt := &block.Statements[i]
errors = append(errors, validateStatement(stmt, block, functionLocals)...) errors = append(errors, v.validateStatement(stmt, block, functionLocals)...)
} }
return errors return errors
} }
func validateFunction(function *ParsedFunction) []error { func (v *Validator) validateFunction(function *ParsedFunction) []error {
var errors []error var errors []error
var locals []Local var locals []Local
@ -226,22 +288,22 @@ func validateFunction(function *ParsedFunction) []error {
body.Locals[param.Name] = local body.Locals[param.Name] = local
} }
errors = append(errors, validateBlock(body, &locals)...) errors = append(errors, v.validateBlock(body, &locals)...)
function.Locals = locals function.Locals = locals
return errors return errors
} }
func validator(file *ParsedFile) []error { func (v *Validator) validate() []error {
var errors []error var errors []error
for i := range file.Imports { for i := range v.file.Imports {
errors = append(errors, validateImport(&file.Imports[i])...) errors = append(errors, v.validateImport(&v.file.Imports[i])...)
} }
for i := range file.Functions { for i := range v.file.Functions {
errors = append(errors, validateFunction(&file.Functions[i])...) errors = append(errors, v.validateFunction(&v.file.Functions[i])...)
} }
return errors return errors