Improve compiler & validator, Update lexer

This commit is contained in:
MrLetsplay 2024-03-20 19:26:48 +01:00
parent 8519f38bab
commit 63ccacba2d
Signed by: mr
SSH Key Fingerprint: SHA256:92jBH80vpXyaZHjaIl47pjRq+Yt7XGTArqQg1V7hSqg
6 changed files with 282 additions and 98 deletions

View File

@ -2,6 +2,7 @@ package main
import (
"errors"
"log"
"strconv"
)
@ -136,9 +137,11 @@ func castPrimitiveWAT(from PrimitiveType, to PrimitiveType) (string, error) {
}
func compileExpressionWAT(expr Expression, block Block) (string, error) {
var err error
switch expr.Type {
case Expression_Assignment:
// TODO
case Expression_Literal:
lit := expr.Value.(LiteralExpression)
switch lit.Literal.Type {
@ -165,6 +168,8 @@ func compileExpressionWAT(expr Expression, block Block) (string, error) {
case Expression_Arithmetic:
arith := expr.Value.(ArithmeticExpression)
log.Printf("%+#v", arith)
// TODO: currently fine, only allowed for primitive types, but should be expanded to allow e.g. strings
exprType := expr.ValueType.Value.(PrimitiveType)
@ -212,9 +217,51 @@ func compileExpressionWAT(expr Expression, block Block) (string, error) {
return watLeft + castLeft + watRight + castRight + op + getTypeCast(exprType), nil
case Expression_Tuple:
tuple := expr.Value.(TupleExpression)
wat := ""
for _, member := range tuple.Members {
memberWAT, err := compileExpressionWAT(member, block)
if err != nil {
return "", err
}
wat += memberWAT
}
return wat, nil
case Expression_FunctionCall:
fc := expr.Value.(FunctionCallExpression)
wat := ""
if fc.Parameters != nil {
wat, err = compileExpressionWAT(*fc.Parameters, block)
if err != nil {
return "", err
}
}
return wat + "call $" + fc.Function + "\n", nil
case Expression_Negate:
neg := expr.Value.(NegateExpression)
exprType := expr.ValueType.Value.(PrimitiveType)
wat, err := compileExpressionWAT(neg.Value, block)
if err != nil {
return "", err
}
watType := getPrimitiveWATType(exprType)
if isSignedInt(exprType) || isUnsignedInt(exprType) {
return watType + ".const 0\n" + wat + watType + ".sub\n", nil
}
if isFloatingPoint(exprType) {
return watType + ".neg\n", nil
}
}
return "", nil
panic("expr not implemented")
}
func compileStatementWAT(stmt Statement, block Block) (string, error) {
@ -286,7 +333,14 @@ func compileFunctionWAT(function ParsedFunction) (string, error) {
}
// TODO: tuples
funcWAT += "\t(result " + getWATType(function.ReturnType) + ")\n"
returnTypes := []Type{function.ReturnType}
if function.ReturnType.Type == Type_Tuple {
returnTypes = function.ReturnType.Value.(TupleType).Types
}
for _, t := range returnTypes {
funcWAT += "\t(result " + getWATType(t) + ")\n"
}
for _, local := range function.Locals {
if local.IsParameter {

View File

@ -1,3 +1,7 @@
u64 add(u8 a, u64 b) {
return a * a + b * b;
u64 add(u8 a, u8 b) {
return add(a - 1u8, a);
}
(u8, u8) doNothing(u8 a, u8 b) {
return a, b;
}

View File

@ -212,8 +212,7 @@ func (l *Lexer) parseToken(token string) (*LexToken, error) {
}
runes := []rune(token)
startsWithMinus := runes[0] == '-'
if startsWithMinus || unicode.IsDigit([]rune(token)[0]) {
if unicode.IsDigit([]rune(token)[0]) {
// TODO: hexadecimal/binary/octal constants
var numberType PrimitiveType = InvalidValue
@ -230,10 +229,8 @@ func (l *Lexer) parseToken(token string) (*LexToken, error) {
if numberType == InvalidValue {
if containsDot {
numberType = Primitive_F64
} else if startsWithMinus {
numberType = Primitive_I64
} else {
numberType = Primitive_U64
numberType = Primitive_I64
}
}

View File

@ -39,7 +39,8 @@ func main() {
log.Printf("Parsed:\n%+#v\n\n", parsed)
errors := validator(parsed)
validator := Validator{file: parsed}
errors := validator.validate()
if len(errors) != 0 {
for _, err = range errors {
if c, ok := err.(CompilerError); ok {

164
parser.go
View File

@ -72,6 +72,8 @@ const (
Expression_VariableReference
Expression_Arithmetic
Expression_Tuple
Expression_FunctionCall
Expression_Negate
)
type Expression struct {
@ -113,6 +115,15 @@ type TupleExpression struct {
Members []Expression
}
type FunctionCallExpression struct {
Function string
Parameters *Expression
}
type NegateExpression struct {
Value Expression
}
type Local struct {
Name string
Type Type
@ -413,8 +424,51 @@ func (p *Parser) tryUnaryExpression() (*Expression, error) {
}
}
if token.Type == Type_Operator {
op := token.Value.(Operator)
if op == Operator_Minus || op == Operator_Plus {
pCopy.nextToken()
expr, err := pCopy.tryUnaryExpression()
if err != nil {
return nil, err
}
if expr == nil {
return nil, nil
}
if op == Operator_Minus {
expr = &Expression{Type: Expression_Negate, Value: NegateExpression{Value: *expr}}
}
return expr, nil
}
}
if token.Type == Type_Identifier {
pCopy.nextToken()
next, err := pCopy.trySeparator(Separator_OpenParen)
if err != nil {
return nil, err
}
if next != nil {
// Function call
params, err := pCopy.tryTupleExpression()
if err != nil {
return nil, err
}
_, err = pCopy.expectSeparator(Separator_CloseParen)
if err != nil {
return nil, err
}
*p = pCopy
return &Expression{Type: Expression_FunctionCall, Value: FunctionCallExpression{Function: token.Value.(string), Parameters: params}}, nil
}
*p = pCopy
return &Expression{Type: Expression_VariableReference, Value: VariableReferenceExpression{Variable: token.Value.(string)}}, nil
}
@ -428,41 +482,47 @@ func (p *Parser) tryMultiplicativeExpression() (*Expression, error) {
return nil, err
}
op, err := p.tryOperator(Operator_Multiply, Operator_Divide, Operator_Modulo)
if err != nil {
return nil, err
if left == nil {
return nil, nil
}
if op == nil {
return left, nil
}
for {
op, err := p.tryOperator(Operator_Multiply, Operator_Divide, Operator_Modulo)
if err != nil {
return nil, err
}
right, err := p.tryUnaryExpression()
if err != nil {
return nil, err
}
if op == nil {
return left, nil
}
if right == nil {
return nil, p.error("expected expression")
}
right, err := p.tryUnaryExpression()
if err != nil {
return nil, err
}
var operation ArithmeticOperation
switch *op {
case Operator_Multiply:
operation = Arithmetic_Mul
case Operator_Divide:
operation = Arithmetic_Div
case Operator_Plus:
operation = Arithmetic_Add
case Operator_Minus:
operation = Arithmetic_Sub
case Operator_Modulo:
fallthrough
default:
operation = Arithmetic_Mod
}
if right == nil {
return nil, p.error("expected expression")
}
return &Expression{Type: Expression_Arithmetic, Value: ArithmeticExpression{Operation: operation, Left: *left, Right: *right}}, nil
var operation ArithmeticOperation
switch *op {
case Operator_Multiply:
operation = Arithmetic_Mul
case Operator_Divide:
operation = Arithmetic_Div
case Operator_Plus:
operation = Arithmetic_Add
case Operator_Minus:
operation = Arithmetic_Sub
case Operator_Modulo:
fallthrough
default:
operation = Arithmetic_Mod
}
left = &Expression{Type: Expression_Arithmetic, Value: ArithmeticExpression{Operation: operation, Left: *left, Right: *right}}
}
}
func (p *Parser) tryAdditiveExpression() (*Expression, error) {
@ -471,32 +531,38 @@ func (p *Parser) tryAdditiveExpression() (*Expression, error) {
return nil, err
}
op, err := p.tryOperator(Operator_Plus, Operator_Minus)
if err != nil {
return nil, err
if left == nil {
return nil, nil
}
if op == nil {
return left, nil
}
for {
op, err := p.tryOperator(Operator_Plus, Operator_Minus)
if err != nil {
return nil, err
}
right, err := p.tryMultiplicativeExpression()
if err != nil {
return nil, err
}
if op == nil {
return left, nil
}
if right == nil {
return nil, p.error("expected expression")
}
right, err := p.tryMultiplicativeExpression()
if err != nil {
return nil, err
}
var operation ArithmeticOperation
if *op == Operator_Plus {
operation = Arithmetic_Add
} else {
operation = Arithmetic_Sub
}
if right == nil {
return nil, p.error("expected expression")
}
return &Expression{Type: Expression_Arithmetic, Value: ArithmeticExpression{Operation: operation, Left: *left, Right: *right}}, nil
var operation ArithmeticOperation
if *op == Operator_Plus {
operation = Arithmetic_Add
} else {
operation = Arithmetic_Sub
}
left = &Expression{Type: Expression_Arithmetic, Value: ArithmeticExpression{Operation: operation, Left: *left, Right: *right}}
}
}
func (p *Parser) tryArithmeticExpression() (*Expression, error) {

View File

@ -2,19 +2,22 @@ package main
import (
"errors"
"strconv"
)
func createError(message string) error {
// TODO: pass token and get actual token position
return errors.New(message)
type Validator struct {
file *ParsedFile
}
func validateImport(imp *Import) []error {
// TODO
return nil
func isTypeExpandableTo(from Type, to Type) bool {
if from.Type == Type_Primitive && to.Type == Type_Primitive {
return isPrimitiveTypeExpandableTo(from.Value.(PrimitiveType), to.Value.(PrimitiveType))
}
panic("not implemented")
}
func isTypeExpandableTo(from PrimitiveType, to PrimitiveType) bool {
func isPrimitiveTypeExpandableTo(from PrimitiveType, to PrimitiveType) bool {
if from == to {
return true
}
@ -46,25 +49,35 @@ func isTypeExpandableTo(from PrimitiveType, to PrimitiveType) bool {
return false
}
func getArithmeticResultType(left PrimitiveType, right PrimitiveType, operation ArithmeticOperation) (PrimitiveType, error) {
func (v *Validator) createError(message string) error {
// TODO: pass token and get actual token position
return errors.New(message)
}
func (v *Validator) validateImport(imp *Import) []error {
// TODO
return nil
}
func (v *Validator) getArithmeticResultType(left PrimitiveType, right PrimitiveType, operation ArithmeticOperation) (PrimitiveType, error) {
if left == Primitive_Bool || right == Primitive_Bool {
return InvalidValue, createError("bool type cannot be used in arithmetic expressions")
return InvalidValue, v.createError("bool type cannot be used in arithmetic expressions")
}
if isTypeExpandableTo(left, right) {
if isPrimitiveTypeExpandableTo(left, right) {
return right, nil
}
if isTypeExpandableTo(right, left) {
if isPrimitiveTypeExpandableTo(right, left) {
return left, nil
}
// TODO: boolean expressions etc.
return InvalidValue, createError("cannot use these types in an arithmetic expression without an explicit cast") // TODO: include type names in error
return InvalidValue, v.createError("cannot use these types in an arithmetic expression without an explicit cast") // TODO: include type names in error
}
func validateExpression(expr *Expression, block *Block) []error {
func (v *Validator) validateExpression(expr *Expression, block *Block) []error {
var errors []error
switch expr.Type {
@ -73,11 +86,11 @@ func validateExpression(expr *Expression, block *Block) []error {
var local Local
var ok bool
if local, ok = block.Locals[assignment.Variable]; !ok {
errors = append(errors, createError("Assignment to undeclared variable "+assignment.Variable))
errors = append(errors, v.createError("Assignment to undeclared variable "+assignment.Variable))
return errors
}
valueErrors := validateExpression(&assignment.Value, block)
valueErrors := v.validateExpression(&assignment.Value, block)
if len(valueErrors) != 0 {
errors = append(errors, valueErrors...)
return errors
@ -100,31 +113,29 @@ func validateExpression(expr *Expression, block *Block) []error {
var local Local
var ok bool
if local, ok = block.Locals[reference.Variable]; !ok {
errors = append(errors, createError("Reference to undeclared variable "+reference.Variable))
errors = append(errors, v.createError("Reference to undeclared variable "+reference.Variable))
return errors
}
expr.ValueType = local.Type
expr.Value = reference
case Expression_Arithmetic:
arithmethic := expr.Value.(ArithmeticExpression)
errors = append(errors, validateExpression(&arithmethic.Left, block)...)
errors = append(errors, validateExpression(&arithmethic.Right, block)...)
errors = append(errors, v.validateExpression(&arithmethic.Left, block)...)
errors = append(errors, v.validateExpression(&arithmethic.Right, block)...)
if len(errors) != 0 {
return errors
}
// TODO: validate types compatible and determine result type
if arithmethic.Left.ValueType.Type != Type_Primitive || arithmethic.Right.ValueType.Type != Type_Primitive {
errors = append(errors, createError("both sides of an arithmetic expression must be a primitive type"))
errors = append(errors, v.createError("both sides of an arithmetic expression must be a primitive type"))
return errors
}
leftType := arithmethic.Left.ValueType.Value.(PrimitiveType)
rightType := arithmethic.Right.ValueType.Value.(PrimitiveType)
result, err := getArithmeticResultType(leftType, rightType, arithmethic.Operation)
result, err := v.getArithmeticResultType(leftType, rightType, arithmethic.Operation)
if err != nil {
errors = append(errors, err)
return errors
@ -139,7 +150,7 @@ func validateExpression(expr *Expression, block *Block) []error {
for i := range tuple.Members {
member := &tuple.Members[i]
memberErrors := validateExpression(member, block)
memberErrors := v.validateExpression(member, block)
if len(memberErrors) != 0 {
errors = append(errors, memberErrors...)
continue
@ -154,12 +165,63 @@ func validateExpression(expr *Expression, block *Block) []error {
expr.ValueType = Type{Type: Type_Tuple, Value: TupleType{Types: types}}
expr.Value = tuple
case Expression_FunctionCall:
fc := expr.Value.(FunctionCallExpression)
var calledFunc *ParsedFunction = nil
for _, f := range v.file.Functions {
if f.Name == fc.Function {
calledFunc = &f
break
}
}
if calledFunc == nil {
errors = append(errors, v.createError("call to undefined function '"+fc.Function+"'"))
return errors
}
if fc.Parameters != nil {
errors = append(errors, v.validateExpression(fc.Parameters, block)...)
params := []Expression{*fc.Parameters}
if fc.Parameters.Type == Expression_Tuple {
params = fc.Parameters.Value.(TupleExpression).Members
}
if len(params) != len(calledFunc.Parameters) {
errors = append(errors, v.createError("wrong number of arguments in function call: expected "+strconv.Itoa(len(calledFunc.Parameters))+", got "+strconv.Itoa(len(params))))
}
for i := 0; i < min(len(params), len(calledFunc.Parameters)); i++ {
typeGiven := params[i]
typeExpected := calledFunc.Parameters[i]
if !isTypeExpandableTo(typeGiven.ValueType, typeExpected.Type) {
errors = append(errors, v.createError("invalid type for parameter "+strconv.Itoa(i)))
}
}
}
// TODO: get function and validate using return type
expr.ValueType = calledFunc.ReturnType
expr.Value = fc
case Expression_Negate:
neg := expr.Value.(NegateExpression)
errors = append(errors, v.validateExpression(&neg.Value, block)...)
if neg.Value.ValueType.Type != Type_Primitive {
errors = append(errors, v.createError("cannot negate non-number types"))
}
expr.ValueType = neg.Value.ValueType
expr.Value = neg
}
return errors
}
func validateStatement(stmt *Statement, block *Block, functionLocals *[]Local) []error {
func (v *Validator) validateStatement(stmt *Statement, block *Block, functionLocals *[]Local) []error {
var errors []error
// TODO: support references to variables in parent block
@ -167,25 +229,25 @@ func validateStatement(stmt *Statement, block *Block, functionLocals *[]Local) [
switch stmt.Type {
case Statement_Expression:
expression := stmt.Value.(ExpressionStatement)
errors = append(errors, validateExpression(&expression.Expression, block)...)
errors = append(errors, v.validateExpression(&expression.Expression, block)...)
*stmt = Statement{Type: Statement_Expression, Value: expression}
case Statement_Block:
block := stmt.Value.(BlockStatement)
errors = append(errors, validateBlock(&block.Block, functionLocals)...)
errors = append(errors, v.validateBlock(&block.Block, functionLocals)...)
*stmt = Statement{Type: Statement_Block, Value: block}
case Statement_Return:
ret := stmt.Value.(ReturnStatement)
if ret.Value != nil {
errors = append(errors, validateExpression(ret.Value, block)...)
errors = append(errors, v.validateExpression(ret.Value, block)...)
}
case Statement_DeclareLocalVariable:
dlv := stmt.Value.(DeclareLocalVariableStatement)
if dlv.Initializer != nil {
errors = append(errors, validateExpression(dlv.Initializer, block)...)
errors = append(errors, v.validateExpression(dlv.Initializer, block)...)
}
if _, ok := block.Locals[dlv.Variable]; ok {
errors = append(errors, createError("redeclaration of variable '"+dlv.Variable+"'"))
errors = append(errors, v.createError("redeclaration of variable '"+dlv.Variable+"'"))
}
local := Local{Name: dlv.Variable, Type: dlv.VariableType, IsParameter: false, Index: len(*functionLocals)}
@ -198,7 +260,7 @@ func validateStatement(stmt *Statement, block *Block, functionLocals *[]Local) [
return errors
}
func validateBlock(block *Block, functionLocals *[]Local) []error {
func (v *Validator) validateBlock(block *Block, functionLocals *[]Local) []error {
var errors []error
if block.Locals == nil {
@ -207,13 +269,13 @@ func validateBlock(block *Block, functionLocals *[]Local) []error {
for i := range block.Statements {
stmt := &block.Statements[i]
errors = append(errors, validateStatement(stmt, block, functionLocals)...)
errors = append(errors, v.validateStatement(stmt, block, functionLocals)...)
}
return errors
}
func validateFunction(function *ParsedFunction) []error {
func (v *Validator) validateFunction(function *ParsedFunction) []error {
var errors []error
var locals []Local
@ -226,22 +288,22 @@ func validateFunction(function *ParsedFunction) []error {
body.Locals[param.Name] = local
}
errors = append(errors, validateBlock(body, &locals)...)
errors = append(errors, v.validateBlock(body, &locals)...)
function.Locals = locals
return errors
}
func validator(file *ParsedFile) []error {
func (v *Validator) validate() []error {
var errors []error
for i := range file.Imports {
errors = append(errors, validateImport(&file.Imports[i])...)
for i := range v.file.Imports {
errors = append(errors, v.validateImport(&v.file.Imports[i])...)
}
for i := range file.Functions {
errors = append(errors, validateFunction(&file.Functions[i])...)
for i := range v.file.Functions {
errors = append(errors, v.validateFunction(&v.file.Functions[i])...)
}
return errors