Compile if statements
This commit is contained in:
parent
c0b9ee086a
commit
fa63fee64d
@ -2,7 +2,6 @@ package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"strconv"
|
||||
"unicode"
|
||||
)
|
||||
@ -181,29 +180,28 @@ func compileExpressionWAT(expr Expression, block Block) (string, error) {
|
||||
|
||||
return "local.get $" + strconv.Itoa(block.Locals[ref.Variable].Index) + "\n" + cast, nil
|
||||
case Expression_Binary:
|
||||
arith := expr.Value.(BinaryExpression)
|
||||
|
||||
log.Printf("%+#v", arith)
|
||||
binary := expr.Value.(BinaryExpression)
|
||||
|
||||
// TODO: currently fine, only allowed for primitive types, but should be expanded to allow e.g. strings
|
||||
resultType := binary.ResultType.Value.(PrimitiveType)
|
||||
exprType := expr.ValueType.Value.(PrimitiveType)
|
||||
|
||||
watLeft, err := compileExpressionWAT(arith.Left, block)
|
||||
watLeft, err := compileExpressionWAT(binary.Left, block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
castLeft, err := castPrimitiveWAT(arith.Left.ValueType.Value.(PrimitiveType), exprType)
|
||||
castLeft, err := castPrimitiveWAT(binary.Left.ValueType.Value.(PrimitiveType), resultType)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
watRight, err := compileExpressionWAT(arith.Right, block)
|
||||
watRight, err := compileExpressionWAT(binary.Right, block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
castRight, err := castPrimitiveWAT(arith.Right.ValueType.Value.(PrimitiveType), exprType)
|
||||
castRight, err := castPrimitiveWAT(binary.Right.ValueType.Value.(PrimitiveType), resultType)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@ -211,23 +209,35 @@ func compileExpressionWAT(expr Expression, block Block) (string, error) {
|
||||
op := ""
|
||||
|
||||
suffix := ""
|
||||
if isUnsignedInt(exprType) {
|
||||
if isUnsignedInt(resultType) {
|
||||
suffix = "u"
|
||||
} else {
|
||||
suffix = "s"
|
||||
}
|
||||
|
||||
switch arith.Operation {
|
||||
switch binary.Operation {
|
||||
case Operation_Add:
|
||||
op = getPrimitiveWATType(exprType) + ".add\n"
|
||||
op = getPrimitiveWATType(resultType) + ".add\n"
|
||||
case Operation_Sub:
|
||||
op = getPrimitiveWATType(exprType) + ".sub\n"
|
||||
op = getPrimitiveWATType(resultType) + ".sub\n"
|
||||
case Operation_Mul:
|
||||
op = getPrimitiveWATType(exprType) + ".mul\n"
|
||||
op = getPrimitiveWATType(resultType) + ".mul\n"
|
||||
case Operation_Div:
|
||||
op = getPrimitiveWATType(exprType) + ".div_" + suffix + "\n"
|
||||
op = getPrimitiveWATType(resultType) + ".div_" + suffix + "\n"
|
||||
case Operation_Mod:
|
||||
op = getPrimitiveWATType(exprType) + ".rem_" + suffix + "\n"
|
||||
op = getPrimitiveWATType(resultType) + ".rem_" + suffix + "\n"
|
||||
case Operation_Greater:
|
||||
op = getPrimitiveWATType(resultType) + ".gt_" + suffix + "\n"
|
||||
case Operation_Less:
|
||||
op = getPrimitiveWATType(resultType) + ".lt_" + suffix + "\n"
|
||||
case Operation_GreaterEquals:
|
||||
op = getPrimitiveWATType(resultType) + ".ge_" + suffix + "\n"
|
||||
case Operation_LessEquals:
|
||||
op = getPrimitiveWATType(resultType) + ".le_" + suffix + "\n"
|
||||
case Operation_NotEquals:
|
||||
op = getPrimitiveWATType(resultType) + ".ne\n"
|
||||
case Operation_Equals:
|
||||
op = getPrimitiveWATType(resultType) + ".eq\n"
|
||||
default:
|
||||
panic("operation not implemented")
|
||||
}
|
||||
@ -306,6 +316,8 @@ func compileStatementWAT(stmt Statement, block Block) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// TODO: upcast to return type for non-primitive types
|
||||
|
||||
return wat + "return\n", nil
|
||||
case Statement_DeclareLocalVariable:
|
||||
dlv := stmt.Value.(DeclareLocalVariableStatement)
|
||||
@ -319,6 +331,53 @@ func compileStatementWAT(stmt Statement, block Block) (string, error) {
|
||||
}
|
||||
|
||||
return wat + "local.set $" + strconv.Itoa(block.Locals[dlv.Variable].Index) + "\n", nil
|
||||
case Statement_If:
|
||||
ifS := stmt.Value.(IfStatement)
|
||||
|
||||
conditionWAT, err := compileExpressionWAT(ifS.Condition, block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
condBlockWAT, err := compileBlockWAT(ifS.ConditionalBlock)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
wat := ""
|
||||
|
||||
if ifS.ElseBlock != nil {
|
||||
wat += "block\n"
|
||||
}
|
||||
|
||||
// condition
|
||||
wat += "block\n"
|
||||
|
||||
wat += conditionWAT
|
||||
wat += "i32.eqz\n" // logical not
|
||||
wat += "br_if 0\n"
|
||||
|
||||
// condition is true
|
||||
wat += condBlockWAT
|
||||
|
||||
if ifS.ElseBlock != nil {
|
||||
wat += "br 1\n" // jump over else block
|
||||
}
|
||||
|
||||
wat += "end\n"
|
||||
|
||||
if ifS.ElseBlock != nil {
|
||||
// condition is false
|
||||
elseWAT, err := compileBlockWAT(*ifS.ElseBlock)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
wat += elseWAT
|
||||
wat += "end\n"
|
||||
}
|
||||
|
||||
return wat, nil
|
||||
}
|
||||
|
||||
panic("stmt not implemented")
|
||||
|
@ -2,14 +2,6 @@ u64 add(u8 a, u8 b) {
|
||||
return add(a - 1u8, a);
|
||||
}
|
||||
|
||||
u64 add2(u64 a, u64 b) {
|
||||
if(a == b) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return a + b;
|
||||
}
|
||||
|
||||
void a() {
|
||||
|
||||
}
|
||||
@ -17,3 +9,19 @@ void a() {
|
||||
(u8, u8) doNothing(u8 a, u8 b) {
|
||||
return a, b;
|
||||
}
|
||||
|
||||
u64 doStuff(u64 a, u64 b) {
|
||||
if(a > b) {
|
||||
return 1u64;
|
||||
}
|
||||
|
||||
return 2u64;
|
||||
}
|
||||
|
||||
u64 fib(u64 n) {
|
||||
if(n <= 1u64) {
|
||||
return 1u64;
|
||||
}
|
||||
|
||||
return fib(n - 1u64) + fib(n - 2u64);
|
||||
}
|
||||
|
@ -125,7 +125,7 @@ type BinaryExpression struct {
|
||||
Operation Operation
|
||||
Left Expression
|
||||
Right Expression
|
||||
ResultType *Type
|
||||
ResultType *Type // Type to expand the operands to before performing the operation
|
||||
}
|
||||
|
||||
type TupleExpression struct {
|
||||
|
85
validator.go
85
validator.go
@ -1,18 +1,45 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type Validator struct {
|
||||
file *ParsedFile
|
||||
|
||||
currentBlock *Block
|
||||
currentFunction *ParsedFunction
|
||||
}
|
||||
|
||||
func isTypeExpandableTo(from Type, to Type) bool {
|
||||
if from.Type == Type_Primitive && to.Type == Type_Primitive {
|
||||
if from.Type != to.Type {
|
||||
// cannot convert between primitive, named, array and tuple types
|
||||
return false
|
||||
}
|
||||
|
||||
if from.Type == Type_Primitive {
|
||||
return isPrimitiveTypeExpandableTo(from.Value.(PrimitiveType), to.Value.(PrimitiveType))
|
||||
}
|
||||
|
||||
if from.Type == Type_Tuple {
|
||||
fromT := from.Value.(TupleType)
|
||||
toT := to.Value.(TupleType)
|
||||
|
||||
if len(fromT.Types) != len(toT.Types) {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := 0; i < len(fromT.Types); i++ {
|
||||
if !isTypeExpandableTo(fromT.Types[i], toT.Types[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
log.Printf("%+#v %+#v", from, to)
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
@ -76,7 +103,7 @@ func (v *Validator) getArithmeticResultType(expr *Expression, left PrimitiveType
|
||||
return InvalidValue, v.createError("cannot use these types in an arithmetic expression without an explicit cast", expr.Position) // TODO: include type names in error
|
||||
}
|
||||
|
||||
func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *Block) []error {
|
||||
func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error {
|
||||
var errors []error
|
||||
|
||||
switch expr.Type {
|
||||
@ -84,12 +111,12 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B
|
||||
assignment := expr.Value.(AssignmentExpression)
|
||||
var local Local
|
||||
var ok bool
|
||||
if local, ok = block.Locals[assignment.Variable]; !ok {
|
||||
if local, ok = v.currentBlock.Locals[assignment.Variable]; !ok {
|
||||
errors = append(errors, v.createError("Assignment to undeclared variable "+assignment.Variable, expr.Position))
|
||||
return errors
|
||||
}
|
||||
|
||||
valueErrors := v.validateExpression(&assignment.Value, block)
|
||||
valueErrors := v.validateExpression(&assignment.Value)
|
||||
if len(valueErrors) != 0 {
|
||||
errors = append(errors, valueErrors...)
|
||||
return errors
|
||||
@ -111,7 +138,7 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B
|
||||
reference := expr.Value.(VariableReferenceExpression)
|
||||
var local Local
|
||||
var ok bool
|
||||
if local, ok = block.Locals[reference.Variable]; !ok {
|
||||
if local, ok = v.currentBlock.Locals[reference.Variable]; !ok {
|
||||
errors = append(errors, v.createError("Reference to undeclared variable "+reference.Variable, expr.Position))
|
||||
return errors
|
||||
}
|
||||
@ -120,8 +147,8 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B
|
||||
case Expression_Binary:
|
||||
binary := expr.Value.(BinaryExpression)
|
||||
|
||||
errors = append(errors, v.validateExpression(&binary.Left, block)...)
|
||||
errors = append(errors, v.validateExpression(&binary.Right, block)...)
|
||||
errors = append(errors, v.validateExpression(&binary.Left)...)
|
||||
errors = append(errors, v.validateExpression(&binary.Right)...)
|
||||
|
||||
if len(errors) != 0 {
|
||||
return errors
|
||||
@ -180,7 +207,7 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B
|
||||
for i := range tuple.Members {
|
||||
member := &tuple.Members[i]
|
||||
|
||||
memberErrors := v.validateExpression(member, block)
|
||||
memberErrors := v.validateExpression(member)
|
||||
if len(memberErrors) != 0 {
|
||||
errors = append(errors, memberErrors...)
|
||||
continue
|
||||
@ -212,7 +239,7 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B
|
||||
}
|
||||
|
||||
if fc.Parameters != nil {
|
||||
paramsErrors := v.validateExpression(fc.Parameters, block)
|
||||
paramsErrors := v.validateExpression(fc.Parameters)
|
||||
if len(paramsErrors) != 0 {
|
||||
errors = append(errors, paramsErrors...)
|
||||
return errors
|
||||
@ -242,7 +269,7 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B
|
||||
case Expression_Negate:
|
||||
neg := expr.Value.(NegateExpression)
|
||||
|
||||
valErrors := v.validateExpression(&neg.Value, block)
|
||||
valErrors := v.validateExpression(&neg.Value)
|
||||
if len(valErrors) != 0 {
|
||||
errors = append(errors, valErrors...)
|
||||
return errors
|
||||
@ -261,8 +288,8 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression, block *B
|
||||
return errors
|
||||
}
|
||||
|
||||
func (v *Validator) validateExpression(expr *Expression, block *Block) []error {
|
||||
errors := v.validatePotentiallyVoidExpression(expr, block)
|
||||
func (v *Validator) validateExpression(expr *Expression) []error {
|
||||
errors := v.validatePotentiallyVoidExpression(expr)
|
||||
if len(errors) != 0 {
|
||||
return errors
|
||||
}
|
||||
@ -274,7 +301,7 @@ func (v *Validator) validateExpression(expr *Expression, block *Block) []error {
|
||||
return errors
|
||||
}
|
||||
|
||||
func (v *Validator) validateStatement(stmt *Statement, block *Block, functionLocals *[]Local) []error {
|
||||
func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local) []error {
|
||||
var errors []error
|
||||
|
||||
// TODO: support references to variables in parent block
|
||||
@ -282,7 +309,7 @@ func (v *Validator) validateStatement(stmt *Statement, block *Block, functionLoc
|
||||
switch stmt.Type {
|
||||
case Statement_Expression:
|
||||
expression := stmt.Value.(ExpressionStatement)
|
||||
errors = append(errors, v.validatePotentiallyVoidExpression(&expression.Expression, block)...)
|
||||
errors = append(errors, v.validatePotentiallyVoidExpression(&expression.Expression)...)
|
||||
stmt.Value = expression
|
||||
case Statement_Block:
|
||||
block := stmt.Value.(BlockStatement)
|
||||
@ -291,20 +318,35 @@ func (v *Validator) validateStatement(stmt *Statement, block *Block, functionLoc
|
||||
case Statement_Return:
|
||||
ret := stmt.Value.(ReturnStatement)
|
||||
if ret.Value != nil {
|
||||
errors = append(errors, v.validateExpression(ret.Value, block)...)
|
||||
if v.currentFunction.ReturnType == nil {
|
||||
errors = append(errors, v.createError("cannot return value from void function", stmt.Position))
|
||||
return errors
|
||||
}
|
||||
|
||||
errors = append(errors, v.validateExpression(ret.Value)...)
|
||||
|
||||
if len(errors) != 0 {
|
||||
return errors
|
||||
}
|
||||
|
||||
if !isTypeExpandableTo(*ret.Value.ValueType, *v.currentFunction.ReturnType) {
|
||||
errors = append(errors, v.createError("expression type does not match function return type", ret.Value.Position))
|
||||
}
|
||||
} else if v.currentFunction.ReturnType != nil {
|
||||
errors = append(errors, v.createError("missing return value", stmt.Position))
|
||||
}
|
||||
case Statement_DeclareLocalVariable:
|
||||
dlv := stmt.Value.(DeclareLocalVariableStatement)
|
||||
if dlv.Initializer != nil {
|
||||
errors = append(errors, v.validateExpression(dlv.Initializer, block)...)
|
||||
errors = append(errors, v.validateExpression(dlv.Initializer)...)
|
||||
}
|
||||
|
||||
if _, ok := block.Locals[dlv.Variable]; ok {
|
||||
if _, ok := v.currentBlock.Locals[dlv.Variable]; ok {
|
||||
errors = append(errors, v.createError("redeclaration of variable '"+dlv.Variable+"'", stmt.Position))
|
||||
}
|
||||
|
||||
local := Local{Name: dlv.Variable, Type: dlv.VariableType, IsParameter: false, Index: len(*functionLocals)}
|
||||
block.Locals[dlv.Variable] = local
|
||||
v.currentBlock.Locals[dlv.Variable] = local
|
||||
*functionLocals = append(*functionLocals, local)
|
||||
|
||||
// TODO: check if assignment of initializer is correct
|
||||
@ -312,7 +354,7 @@ func (v *Validator) validateStatement(stmt *Statement, block *Block, functionLoc
|
||||
case Statement_If:
|
||||
ifS := stmt.Value.(IfStatement)
|
||||
|
||||
errors = append(errors, v.validateExpression(&ifS.Condition, block)...)
|
||||
errors = append(errors, v.validateExpression(&ifS.Condition)...)
|
||||
errors = append(errors, v.validateBlock(&ifS.ConditionalBlock, functionLocals)...)
|
||||
|
||||
if ifS.ElseBlock != nil {
|
||||
@ -343,8 +385,9 @@ func (v *Validator) validateBlock(block *Block, functionLocals *[]Local) []error
|
||||
}
|
||||
|
||||
for i := range block.Statements {
|
||||
v.currentBlock = block
|
||||
stmt := &block.Statements[i]
|
||||
errors = append(errors, v.validateStatement(stmt, block, functionLocals)...)
|
||||
errors = append(errors, v.validateStatement(stmt, functionLocals)...)
|
||||
}
|
||||
|
||||
return errors
|
||||
@ -355,6 +398,8 @@ func (v *Validator) validateFunction(function *ParsedFunction) []error {
|
||||
|
||||
var locals []Local
|
||||
|
||||
v.currentFunction = function
|
||||
|
||||
body := &function.Body
|
||||
body.Locals = make(map[string]Local)
|
||||
for _, param := range function.Parameters {
|
||||
|
Loading…
Reference in New Issue
Block a user