Compile if statements

This commit is contained in:
MrLetsplay 2024-03-24 15:19:45 +01:00
parent c0b9ee086a
commit fa63fee64d
Signed by: mr
SSH Key Fingerprint: SHA256:92jBH80vpXyaZHjaIl47pjRq+Yt7XGTArqQg1V7hSqg
4 changed files with 156 additions and 44 deletions

View File

@ -2,7 +2,6 @@ package main
import ( import (
"errors" "errors"
"log"
"strconv" "strconv"
"unicode" "unicode"
) )
@ -181,29 +180,28 @@ func compileExpressionWAT(expr Expression, block Block) (string, error) {
return "local.get $" + strconv.Itoa(block.Locals[ref.Variable].Index) + "\n" + cast, nil return "local.get $" + strconv.Itoa(block.Locals[ref.Variable].Index) + "\n" + cast, nil
case Expression_Binary: case Expression_Binary:
arith := expr.Value.(BinaryExpression) binary := expr.Value.(BinaryExpression)
log.Printf("%+#v", arith)
// TODO: currently fine, only allowed for primitive types, but should be expanded to allow e.g. strings // TODO: currently fine, only allowed for primitive types, but should be expanded to allow e.g. strings
resultType := binary.ResultType.Value.(PrimitiveType)
exprType := expr.ValueType.Value.(PrimitiveType) exprType := expr.ValueType.Value.(PrimitiveType)
watLeft, err := compileExpressionWAT(arith.Left, block) watLeft, err := compileExpressionWAT(binary.Left, block)
if err != nil { if err != nil {
return "", err return "", err
} }
castLeft, err := castPrimitiveWAT(arith.Left.ValueType.Value.(PrimitiveType), exprType) castLeft, err := castPrimitiveWAT(binary.Left.ValueType.Value.(PrimitiveType), resultType)
if err != nil { if err != nil {
return "", err return "", err
} }
watRight, err := compileExpressionWAT(arith.Right, block) watRight, err := compileExpressionWAT(binary.Right, block)
if err != nil { if err != nil {
return "", err return "", err
} }
castRight, err := castPrimitiveWAT(arith.Right.ValueType.Value.(PrimitiveType), exprType) castRight, err := castPrimitiveWAT(binary.Right.ValueType.Value.(PrimitiveType), resultType)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -211,23 +209,35 @@ func compileExpressionWAT(expr Expression, block Block) (string, error) {
op := "" op := ""
suffix := "" suffix := ""
if isUnsignedInt(exprType) { if isUnsignedInt(resultType) {
suffix = "u" suffix = "u"
} else { } else {
suffix = "s" suffix = "s"
} }
switch arith.Operation { switch binary.Operation {
case Operation_Add: case Operation_Add:
op = getPrimitiveWATType(exprType) + ".add\n" op = getPrimitiveWATType(resultType) + ".add\n"
case Operation_Sub: case Operation_Sub:
op = getPrimitiveWATType(exprType) + ".sub\n" op = getPrimitiveWATType(resultType) + ".sub\n"
case Operation_Mul: case Operation_Mul:
op = getPrimitiveWATType(exprType) + ".mul\n" op = getPrimitiveWATType(resultType) + ".mul\n"
case Operation_Div: case Operation_Div:
op = getPrimitiveWATType(exprType) + ".div_" + suffix + "\n" op = getPrimitiveWATType(resultType) + ".div_" + suffix + "\n"
case Operation_Mod: case Operation_Mod:
op = getPrimitiveWATType(exprType) + ".rem_" + suffix + "\n" op = getPrimitiveWATType(resultType) + ".rem_" + suffix + "\n"
case Operation_Greater:
op = getPrimitiveWATType(resultType) + ".gt_" + suffix + "\n"
case Operation_Less:
op = getPrimitiveWATType(resultType) + ".lt_" + suffix + "\n"
case Operation_GreaterEquals:
op = getPrimitiveWATType(resultType) + ".ge_" + suffix + "\n"
case Operation_LessEquals:
op = getPrimitiveWATType(resultType) + ".le_" + suffix + "\n"
case Operation_NotEquals:
op = getPrimitiveWATType(resultType) + ".ne\n"
case Operation_Equals:
op = getPrimitiveWATType(resultType) + ".eq\n"
default: default:
panic("operation not implemented") panic("operation not implemented")
} }
@ -306,6 +316,8 @@ func compileStatementWAT(stmt Statement, block Block) (string, error) {
return "", err return "", err
} }
// TODO: upcast to return type for non-primitive types
return wat + "return\n", nil return wat + "return\n", nil
case Statement_DeclareLocalVariable: case Statement_DeclareLocalVariable:
dlv := stmt.Value.(DeclareLocalVariableStatement) dlv := stmt.Value.(DeclareLocalVariableStatement)
@ -319,6 +331,53 @@ func compileStatementWAT(stmt Statement, block Block) (string, error) {
} }
return wat + "local.set $" + strconv.Itoa(block.Locals[dlv.Variable].Index) + "\n", nil return wat + "local.set $" + strconv.Itoa(block.Locals[dlv.Variable].Index) + "\n", nil
case Statement_If:
ifS := stmt.Value.(IfStatement)
conditionWAT, err := compileExpressionWAT(ifS.Condition, block)
if err != nil {
return "", err
}
condBlockWAT, err := compileBlockWAT(ifS.ConditionalBlock)
if err != nil {
return "", err
}
wat := ""
if ifS.ElseBlock != nil {
wat += "block\n"
}
// condition
wat += "block\n"
wat += conditionWAT
wat += "i32.eqz\n" // logical not
wat += "br_if 0\n"
// condition is true
wat += condBlockWAT
if ifS.ElseBlock != nil {
wat += "br 1\n" // jump over else block
}
wat += "end\n"
if ifS.ElseBlock != nil {
// condition is false
elseWAT, err := compileBlockWAT(*ifS.ElseBlock)
if err != nil {
return "", err
}
wat += elseWAT
wat += "end\n"
}
return wat, nil
} }
panic("stmt not implemented") panic("stmt not implemented")

View File

@ -2,14 +2,6 @@ u64 add(u8 a, u8 b) {
return add(a - 1u8, a); return add(a - 1u8, a);
} }
u64 add2(u64 a, u64 b) {
if(a == b) {
return 0;
}
return a + b;
}
void a() { void a() {
} }
@ -17,3 +9,19 @@ void a() {
(u8, u8) doNothing(u8 a, u8 b) { (u8, u8) doNothing(u8 a, u8 b) {
return a, b; return a, b;
} }
u64 doStuff(u64 a, u64 b) {
if(a > b) {
return 1u64;
}
return 2u64;
}
u64 fib(u64 n) {
if(n <= 1u64) {
return 1u64;
}
return fib(n - 1u64) + fib(n - 2u64);
}

View File

@ -125,7 +125,7 @@ type BinaryExpression struct {
Operation Operation Operation Operation
Left Expression Left Expression
Right Expression Right Expression
ResultType *Type ResultType *Type // Type to expand the operands to before performing the operation
} }
type TupleExpression struct { type TupleExpression struct {

View File

@ -1,18 +1,45 @@
package main package main
import ( import (
"log"
"strconv" "strconv"
) )
type Validator struct { type Validator struct {
file *ParsedFile file *ParsedFile
currentBlock *Block
currentFunction *ParsedFunction
} }
func isTypeExpandableTo(from Type, to Type) bool { func isTypeExpandableTo(from Type, to Type) bool {
if from.Type == Type_Primitive && to.Type == Type_Primitive { if from.Type != to.Type {
// cannot convert between primitive, named, array and tuple types
return false
}
if from.Type == Type_Primitive {
return isPrimitiveTypeExpandableTo(from.Value.(PrimitiveType), to.Value.(PrimitiveType)) return isPrimitiveTypeExpandableTo(from.Value.(PrimitiveType), to.Value.(PrimitiveType))
} }
if from.Type == Type_Tuple {
fromT := from.Value.(TupleType)
toT := to.Value.(TupleType)
if len(fromT.Types) != len(toT.Types) {
return false
}
for i := 0; i < len(fromT.Types); i++ {
if !isTypeExpandableTo(fromT.Types[i], toT.Types[i]) {
return false
}
}
return true
}
log.Printf("%+#v %+#v", from, to)
panic("not implemented") panic("not implemented")
} }
@ -76,7 +103,7 @@ func (v *Validator) getArithmeticResultType(expr *Expression, left PrimitiveType
return InvalidValue, v.createError("cannot use these types in an arithmetic expression without an explicit cast", expr.Position) // TODO: include type names in error return InvalidValue, v.createError("cannot use these types in an arithmetic expression without an explicit cast", expr.Position) // TODO: include type names in error
} }
func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *Block) []error { func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error {
var errors []error var errors []error
switch expr.Type { switch expr.Type {
@ -84,12 +111,12 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B
assignment := expr.Value.(AssignmentExpression) assignment := expr.Value.(AssignmentExpression)
var local Local var local Local
var ok bool var ok bool
if local, ok = block.Locals[assignment.Variable]; !ok { if local, ok = v.currentBlock.Locals[assignment.Variable]; !ok {
errors = append(errors, v.createError("Assignment to undeclared variable "+assignment.Variable, expr.Position)) errors = append(errors, v.createError("Assignment to undeclared variable "+assignment.Variable, expr.Position))
return errors return errors
} }
valueErrors := v.validateExpression(&assignment.Value, block) valueErrors := v.validateExpression(&assignment.Value)
if len(valueErrors) != 0 { if len(valueErrors) != 0 {
errors = append(errors, valueErrors...) errors = append(errors, valueErrors...)
return errors return errors
@ -111,7 +138,7 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B
reference := expr.Value.(VariableReferenceExpression) reference := expr.Value.(VariableReferenceExpression)
var local Local var local Local
var ok bool var ok bool
if local, ok = block.Locals[reference.Variable]; !ok { if local, ok = v.currentBlock.Locals[reference.Variable]; !ok {
errors = append(errors, v.createError("Reference to undeclared variable "+reference.Variable, expr.Position)) errors = append(errors, v.createError("Reference to undeclared variable "+reference.Variable, expr.Position))
return errors return errors
} }
@ -120,8 +147,8 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B
case Expression_Binary: case Expression_Binary:
binary := expr.Value.(BinaryExpression) binary := expr.Value.(BinaryExpression)
errors = append(errors, v.validateExpression(&binary.Left, block)...) errors = append(errors, v.validateExpression(&binary.Left)...)
errors = append(errors, v.validateExpression(&binary.Right, block)...) errors = append(errors, v.validateExpression(&binary.Right)...)
if len(errors) != 0 { if len(errors) != 0 {
return errors return errors
@ -180,7 +207,7 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B
for i := range tuple.Members { for i := range tuple.Members {
member := &tuple.Members[i] member := &tuple.Members[i]
memberErrors := v.validateExpression(member, block) memberErrors := v.validateExpression(member)
if len(memberErrors) != 0 { if len(memberErrors) != 0 {
errors = append(errors, memberErrors...) errors = append(errors, memberErrors...)
continue continue
@ -212,7 +239,7 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B
} }
if fc.Parameters != nil { if fc.Parameters != nil {
paramsErrors := v.validateExpression(fc.Parameters, block) paramsErrors := v.validateExpression(fc.Parameters)
if len(paramsErrors) != 0 { if len(paramsErrors) != 0 {
errors = append(errors, paramsErrors...) errors = append(errors, paramsErrors...)
return errors return errors
@ -242,7 +269,7 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B
case Expression_Negate: case Expression_Negate:
neg := expr.Value.(NegateExpression) neg := expr.Value.(NegateExpression)
valErrors := v.validateExpression(&neg.Value, block) valErrors := v.validateExpression(&neg.Value)
if len(valErrors) != 0 { if len(valErrors) != 0 {
errors = append(errors, valErrors...) errors = append(errors, valErrors...)
return errors return errors
@ -261,8 +288,8 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B
return errors return errors
} }
func (v *Validator) validateExpression(expr *Expression, block *Block) []error { func (v *Validator) validateExpression(expr *Expression) []error {
errors := v.validatePotentiallyVoidExpression(expr, block) errors := v.validatePotentiallyVoidExpression(expr)
if len(errors) != 0 { if len(errors) != 0 {
return errors return errors
} }
@ -274,7 +301,7 @@ func (v *Validator) validateExpression(expr *Expression, block *Block) []error {
return errors return errors
} }
func (v *Validator) validateStatement(stmt *Statement, block *Block, functionLocals *[]Local) []error { func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local) []error {
var errors []error var errors []error
// TODO: support references to variables in parent block // TODO: support references to variables in parent block
@ -282,7 +309,7 @@ func (v *Validator) validateStatement(stmt *Statement, block *Block, functionLoc
switch stmt.Type { switch stmt.Type {
case Statement_Expression: case Statement_Expression:
expression := stmt.Value.(ExpressionStatement) expression := stmt.Value.(ExpressionStatement)
errors = append(errors, v.validatePotentiallyVoidExpression(&expression.Expression, block)...) errors = append(errors, v.validatePotentiallyVoidExpression(&expression.Expression)...)
stmt.Value = expression stmt.Value = expression
case Statement_Block: case Statement_Block:
block := stmt.Value.(BlockStatement) block := stmt.Value.(BlockStatement)
@ -291,20 +318,35 @@ func (v *Validator) validateStatement(stmt *Statement, block *Block, functionLoc
case Statement_Return: case Statement_Return:
ret := stmt.Value.(ReturnStatement) ret := stmt.Value.(ReturnStatement)
if ret.Value != nil { if ret.Value != nil {
errors = append(errors, v.validateExpression(ret.Value, block)...) if v.currentFunction.ReturnType == nil {
errors = append(errors, v.createError("cannot return value from void function", stmt.Position))
return errors
}
errors = append(errors, v.validateExpression(ret.Value)...)
if len(errors) != 0 {
return errors
}
if !isTypeExpandableTo(*ret.Value.ValueType, *v.currentFunction.ReturnType) {
errors = append(errors, v.createError("expression type does not match function return type", ret.Value.Position))
}
} else if v.currentFunction.ReturnType != nil {
errors = append(errors, v.createError("missing return value", stmt.Position))
} }
case Statement_DeclareLocalVariable: case Statement_DeclareLocalVariable:
dlv := stmt.Value.(DeclareLocalVariableStatement) dlv := stmt.Value.(DeclareLocalVariableStatement)
if dlv.Initializer != nil { if dlv.Initializer != nil {
errors = append(errors, v.validateExpression(dlv.Initializer, block)...) errors = append(errors, v.validateExpression(dlv.Initializer)...)
} }
if _, ok := block.Locals[dlv.Variable]; ok { if _, ok := v.currentBlock.Locals[dlv.Variable]; ok {
errors = append(errors, v.createError("redeclaration of variable '"+dlv.Variable+"'", stmt.Position)) errors = append(errors, v.createError("redeclaration of variable '"+dlv.Variable+"'", stmt.Position))
} }
local := Local{Name: dlv.Variable, Type: dlv.VariableType, IsParameter: false, Index: len(*functionLocals)} local := Local{Name: dlv.Variable, Type: dlv.VariableType, IsParameter: false, Index: len(*functionLocals)}
block.Locals[dlv.Variable] = local v.currentBlock.Locals[dlv.Variable] = local
*functionLocals = append(*functionLocals, local) *functionLocals = append(*functionLocals, local)
// TODO: check if assignment of initializer is correct // TODO: check if assignment of initializer is correct
@ -312,7 +354,7 @@ func (v *Validator) validateStatement(stmt *Statement, block *Block, functionLoc
case Statement_If: case Statement_If:
ifS := stmt.Value.(IfStatement) ifS := stmt.Value.(IfStatement)
errors = append(errors, v.validateExpression(&ifS.Condition, block)...) errors = append(errors, v.validateExpression(&ifS.Condition)...)
errors = append(errors, v.validateBlock(&ifS.ConditionalBlock, functionLocals)...) errors = append(errors, v.validateBlock(&ifS.ConditionalBlock, functionLocals)...)
if ifS.ElseBlock != nil { if ifS.ElseBlock != nil {
@ -343,8 +385,9 @@ func (v *Validator) validateBlock(block *Block, functionLocals *[]Local) []error
} }
for i := range block.Statements { for i := range block.Statements {
v.currentBlock = block
stmt := &block.Statements[i] stmt := &block.Statements[i]
errors = append(errors, v.validateStatement(stmt, block, functionLocals)...) errors = append(errors, v.validateStatement(stmt, functionLocals)...)
} }
return errors return errors
@ -355,6 +398,8 @@ func (v *Validator) validateFunction(function *ParsedFunction) []error {
var locals []Local var locals []Local
v.currentFunction = function
body := &function.Body body := &function.Body
body.Locals = make(map[string]Local) body.Locals = make(map[string]Local)
for _, param := range function.Parameters { for _, param := range function.Parameters {