Compile if statements

This commit is contained in:
MrLetsplay 2024-03-24 15:19:45 +01:00
parent c0b9ee086a
commit fa63fee64d
Signed by: mr
SSH Key Fingerprint: SHA256:92jBH80vpXyaZHjaIl47pjRq+Yt7XGTArqQg1V7hSqg
4 changed files with 156 additions and 44 deletions

View File

@ -2,7 +2,6 @@ package main
import (
"errors"
"log"
"strconv"
"unicode"
)
@ -181,29 +180,28 @@ func compileExpressionWAT(expr Expression, block Block) (string, error) {
return "local.get $" + strconv.Itoa(block.Locals[ref.Variable].Index) + "\n" + cast, nil
case Expression_Binary:
arith := expr.Value.(BinaryExpression)
log.Printf("%+#v", arith)
binary := expr.Value.(BinaryExpression)
// TODO: currently fine, only allowed for primitive types, but should be expanded to allow e.g. strings
resultType := binary.ResultType.Value.(PrimitiveType)
exprType := expr.ValueType.Value.(PrimitiveType)
watLeft, err := compileExpressionWAT(arith.Left, block)
watLeft, err := compileExpressionWAT(binary.Left, block)
if err != nil {
return "", err
}
castLeft, err := castPrimitiveWAT(arith.Left.ValueType.Value.(PrimitiveType), exprType)
castLeft, err := castPrimitiveWAT(binary.Left.ValueType.Value.(PrimitiveType), resultType)
if err != nil {
return "", err
}
watRight, err := compileExpressionWAT(arith.Right, block)
watRight, err := compileExpressionWAT(binary.Right, block)
if err != nil {
return "", err
}
castRight, err := castPrimitiveWAT(arith.Right.ValueType.Value.(PrimitiveType), exprType)
castRight, err := castPrimitiveWAT(binary.Right.ValueType.Value.(PrimitiveType), resultType)
if err != nil {
return "", err
}
@ -211,23 +209,35 @@ func compileExpressionWAT(expr Expression, block Block) (string, error) {
op := ""
suffix := ""
if isUnsignedInt(exprType) {
if isUnsignedInt(resultType) {
suffix = "u"
} else {
suffix = "s"
}
switch arith.Operation {
switch binary.Operation {
case Operation_Add:
op = getPrimitiveWATType(exprType) + ".add\n"
op = getPrimitiveWATType(resultType) + ".add\n"
case Operation_Sub:
op = getPrimitiveWATType(exprType) + ".sub\n"
op = getPrimitiveWATType(resultType) + ".sub\n"
case Operation_Mul:
op = getPrimitiveWATType(exprType) + ".mul\n"
op = getPrimitiveWATType(resultType) + ".mul\n"
case Operation_Div:
op = getPrimitiveWATType(exprType) + ".div_" + suffix + "\n"
op = getPrimitiveWATType(resultType) + ".div_" + suffix + "\n"
case Operation_Mod:
op = getPrimitiveWATType(exprType) + ".rem_" + suffix + "\n"
op = getPrimitiveWATType(resultType) + ".rem_" + suffix + "\n"
case Operation_Greater:
op = getPrimitiveWATType(resultType) + ".gt_" + suffix + "\n"
case Operation_Less:
op = getPrimitiveWATType(resultType) + ".lt_" + suffix + "\n"
case Operation_GreaterEquals:
op = getPrimitiveWATType(resultType) + ".ge_" + suffix + "\n"
case Operation_LessEquals:
op = getPrimitiveWATType(resultType) + ".le_" + suffix + "\n"
case Operation_NotEquals:
op = getPrimitiveWATType(resultType) + ".ne\n"
case Operation_Equals:
op = getPrimitiveWATType(resultType) + ".eq\n"
default:
panic("operation not implemented")
}
@ -306,6 +316,8 @@ func compileStatementWAT(stmt Statement, block Block) (string, error) {
return "", err
}
// TODO: upcast to return type for non-primitive types
return wat + "return\n", nil
case Statement_DeclareLocalVariable:
dlv := stmt.Value.(DeclareLocalVariableStatement)
@ -319,6 +331,53 @@ func compileStatementWAT(stmt Statement, block Block) (string, error) {
}
return wat + "local.set $" + strconv.Itoa(block.Locals[dlv.Variable].Index) + "\n", nil
case Statement_If:
ifS := stmt.Value.(IfStatement)
conditionWAT, err := compileExpressionWAT(ifS.Condition, block)
if err != nil {
return "", err
}
condBlockWAT, err := compileBlockWAT(ifS.ConditionalBlock)
if err != nil {
return "", err
}
wat := ""
if ifS.ElseBlock != nil {
wat += "block\n"
}
// condition
wat += "block\n"
wat += conditionWAT
wat += "i32.eqz\n" // logical not
wat += "br_if 0\n"
// condition is true
wat += condBlockWAT
if ifS.ElseBlock != nil {
wat += "br 1\n" // jump over else block
}
wat += "end\n"
if ifS.ElseBlock != nil {
// condition is false
elseWAT, err := compileBlockWAT(*ifS.ElseBlock)
if err != nil {
return "", err
}
wat += elseWAT
wat += "end\n"
}
return wat, nil
}
panic("stmt not implemented")

View File

@ -2,14 +2,6 @@ u64 add(u8 a, u8 b) {
return add(a - 1u8, a);
}
u64 add2(u64 a, u64 b) {
if(a == b) {
return 0;
}
return a + b;
}
void a() {
}
@ -17,3 +9,19 @@ void a() {
(u8, u8) doNothing(u8 a, u8 b) {
return a, b;
}
u64 doStuff(u64 a, u64 b) {
if(a > b) {
return 1u64;
}
return 2u64;
}
u64 fib(u64 n) {
if(n <= 1u64) {
return 1u64;
}
return fib(n - 1u64) + fib(n - 2u64);
}

View File

@ -125,7 +125,7 @@ type BinaryExpression struct {
Operation Operation
Left Expression
Right Expression
ResultType *Type
ResultType *Type // Type to expand the operands to before performing the operation
}
type TupleExpression struct {

View File

@ -1,18 +1,45 @@
package main
import (
"log"
"strconv"
)
type Validator struct {
file *ParsedFile
currentBlock *Block
currentFunction *ParsedFunction
}
func isTypeExpandableTo(from Type, to Type) bool {
if from.Type == Type_Primitive && to.Type == Type_Primitive {
if from.Type != to.Type {
// cannot convert between primitive, named, array and tuple types
return false
}
if from.Type == Type_Primitive {
return isPrimitiveTypeExpandableTo(from.Value.(PrimitiveType), to.Value.(PrimitiveType))
}
if from.Type == Type_Tuple {
fromT := from.Value.(TupleType)
toT := to.Value.(TupleType)
if len(fromT.Types) != len(toT.Types) {
return false
}
for i := 0; i < len(fromT.Types); i++ {
if !isTypeExpandableTo(fromT.Types[i], toT.Types[i]) {
return false
}
}
return true
}
log.Printf("%+#v %+#v", from, to)
panic("not implemented")
}
@ -76,7 +103,7 @@ func (v *Validator) getArithmeticResultType(expr *Expression, left PrimitiveType
return InvalidValue, v.createError("cannot use these types in an arithmetic expression without an explicit cast", expr.Position) // TODO: include type names in error
}
func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *Block) []error {
func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error {
var errors []error
switch expr.Type {
@ -84,12 +111,12 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B
assignment := expr.Value.(AssignmentExpression)
var local Local
var ok bool
if local, ok = block.Locals[assignment.Variable]; !ok {
if local, ok = v.currentBlock.Locals[assignment.Variable]; !ok {
errors = append(errors, v.createError("Assignment to undeclared variable "+assignment.Variable, expr.Position))
return errors
}
valueErrors := v.validateExpression(&assignment.Value, block)
valueErrors := v.validateExpression(&assignment.Value)
if len(valueErrors) != 0 {
errors = append(errors, valueErrors...)
return errors
@ -111,7 +138,7 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B
reference := expr.Value.(VariableReferenceExpression)
var local Local
var ok bool
if local, ok = block.Locals[reference.Variable]; !ok {
if local, ok = v.currentBlock.Locals[reference.Variable]; !ok {
errors = append(errors, v.createError("Reference to undeclared variable "+reference.Variable, expr.Position))
return errors
}
@ -120,8 +147,8 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B
case Expression_Binary:
binary := expr.Value.(BinaryExpression)
errors = append(errors, v.validateExpression(&binary.Left, block)...)
errors = append(errors, v.validateExpression(&binary.Right, block)...)
errors = append(errors, v.validateExpression(&binary.Left)...)
errors = append(errors, v.validateExpression(&binary.Right)...)
if len(errors) != 0 {
return errors
@ -180,7 +207,7 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B
for i := range tuple.Members {
member := &tuple.Members[i]
memberErrors := v.validateExpression(member, block)
memberErrors := v.validateExpression(member)
if len(memberErrors) != 0 {
errors = append(errors, memberErrors...)
continue
@ -212,7 +239,7 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B
}
if fc.Parameters != nil {
paramsErrors := v.validateExpression(fc.Parameters, block)
paramsErrors := v.validateExpression(fc.Parameters)
if len(paramsErrors) != 0 {
errors = append(errors, paramsErrors...)
return errors
@ -242,7 +269,7 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B
case Expression_Negate:
neg := expr.Value.(NegateExpression)
valErrors := v.validateExpression(&neg.Value, block)
valErrors := v.validateExpression(&neg.Value)
if len(valErrors) != 0 {
errors = append(errors, valErrors...)
return errors
@ -261,8 +288,8 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B
return errors
}
func (v *Validator) validateExpression(expr *Expression, block *Block) []error {
errors := v.validatePotentiallyVoidExpression(expr, block)
func (v *Validator) validateExpression(expr *Expression) []error {
errors := v.validatePotentiallyVoidExpression(expr)
if len(errors) != 0 {
return errors
}
@ -274,7 +301,7 @@ func (v *Validator) validateExpression(expr *Expression, block *Block) []error {
return errors
}
func (v *Validator) validateStatement(stmt *Statement, block *Block, functionLocals *[]Local) []error {
func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local) []error {
var errors []error
// TODO: support references to variables in parent block
@ -282,7 +309,7 @@ func (v *Validator) validateStatement(stmt *Statement, block *Block, functionLoc
switch stmt.Type {
case Statement_Expression:
expression := stmt.Value.(ExpressionStatement)
errors = append(errors, v.validatePotentiallyVoidExpression(&expression.Expression, block)...)
errors = append(errors, v.validatePotentiallyVoidExpression(&expression.Expression)...)
stmt.Value = expression
case Statement_Block:
block := stmt.Value.(BlockStatement)
@ -291,20 +318,35 @@ func (v *Validator) validateStatement(stmt *Statement, block *Block, functionLoc
case Statement_Return:
ret := stmt.Value.(ReturnStatement)
if ret.Value != nil {
errors = append(errors, v.validateExpression(ret.Value, block)...)
if v.currentFunction.ReturnType == nil {
errors = append(errors, v.createError("cannot return value from void function", stmt.Position))
return errors
}
errors = append(errors, v.validateExpression(ret.Value)...)
if len(errors) != 0 {
return errors
}
if !isTypeExpandableTo(*ret.Value.ValueType, *v.currentFunction.ReturnType) {
errors = append(errors, v.createError("expression type does not match function return type", ret.Value.Position))
}
} else if v.currentFunction.ReturnType != nil {
errors = append(errors, v.createError("missing return value", stmt.Position))
}
case Statement_DeclareLocalVariable:
dlv := stmt.Value.(DeclareLocalVariableStatement)
if dlv.Initializer != nil {
errors = append(errors, v.validateExpression(dlv.Initializer, block)...)
errors = append(errors, v.validateExpression(dlv.Initializer)...)
}
if _, ok := block.Locals[dlv.Variable]; ok {
if _, ok := v.currentBlock.Locals[dlv.Variable]; ok {
errors = append(errors, v.createError("redeclaration of variable '"+dlv.Variable+"'", stmt.Position))
}
local := Local{Name: dlv.Variable, Type: dlv.VariableType, IsParameter: false, Index: len(*functionLocals)}
block.Locals[dlv.Variable] = local
v.currentBlock.Locals[dlv.Variable] = local
*functionLocals = append(*functionLocals, local)
// TODO: check if assignment of initializer is correct
@ -312,7 +354,7 @@ func (v *Validator) validateStatement(stmt *Statement, block *Block, functionLoc
case Statement_If:
ifS := stmt.Value.(IfStatement)
errors = append(errors, v.validateExpression(&ifS.Condition, block)...)
errors = append(errors, v.validateExpression(&ifS.Condition)...)
errors = append(errors, v.validateBlock(&ifS.ConditionalBlock, functionLocals)...)
if ifS.ElseBlock != nil {
@ -343,8 +385,9 @@ func (v *Validator) validateBlock(block *Block, functionLocals *[]Local) []error
}
for i := range block.Statements {
v.currentBlock = block
stmt := &block.Statements[i]
errors = append(errors, v.validateStatement(stmt, block, functionLocals)...)
errors = append(errors, v.validateStatement(stmt, functionLocals)...)
}
return errors
@ -355,6 +398,8 @@ func (v *Validator) validateFunction(function *ParsedFunction) []error {
var locals []Local
v.currentFunction = function
body := &function.Body
body.Locals = make(map[string]Local)
for _, param := range function.Parameters {