diff --git a/backend_wat.go b/backend_wat.go index b490490..699d165 100644 --- a/backend_wat.go +++ b/backend_wat.go @@ -155,7 +155,23 @@ func compileExpressionWAT(expr Expression, block *Block) (string, error) { switch expr.Type { case Expression_Assignment: - // TODO + ass := expr.Value.(AssignmentExpression) + + exprWAT, err := compileExpressionWAT(ass.Value, block) + if err != nil { + return "", err + } + + cast := "" + if expr.ValueType.Type == Type_Primitive { + cast = getTypeCast(expr.ValueType.Value.(PrimitiveType)) + } + + local := strconv.Itoa(block.Locals[ass.Variable].Index) + getLocal := "local.get $" + local + "\n" + setLocal := "local.set $" + local + "\n" + + return exprWAT + cast + setLocal + getLocal, nil case Expression_Literal: lit := expr.Value.(LiteralExpression) switch lit.Literal.Type { @@ -191,6 +207,7 @@ func compileExpressionWAT(expr Expression, block *Block) (string, error) { return "", err } + // TODO: cast produces unnecessary/wrong cast, make sure to upcast to target type castLeft, err := castPrimitiveWAT(binary.Left.ValueType.Value.(PrimitiveType), resultType) if err != nil { return "", err @@ -330,6 +347,7 @@ func compileStatementWAT(stmt Statement, block *Block) (string, error) { return "", err } + // TODO: make sure to upcast to target type return wat + "local.set $" + strconv.Itoa(block.Locals[dlv.Variable].Index) + "\n", nil case Statement_If: ifS := stmt.Value.(IfStatement) diff --git a/example/add.lang b/example/add.lang index b7d92f6..f324eed 100644 --- a/example/add.lang +++ b/example/add.lang @@ -7,6 +7,7 @@ void a() { } (u8, u8) doNothing(u8 a, u8 b) { + a = b; return a, b; } diff --git a/main.go b/main.go index 4b3a8ee..6ea837c 100644 --- a/main.go +++ b/main.go @@ -8,7 +8,7 @@ import ( "strings" ) -const ERROR_LOG_LINES = 10 +const ERROR_LOG_LINES = 5 func countTabs(line string) int { tabs := 0 diff --git a/parser.go b/parser.go index 1ef002d..41e7c3b 100644 --- a/parser.go +++ b/parser.go @@ -541,8 +541,43 @@ func (p *Parser) tryBinaryExpression() (*Expression, error) { return p.tryEqualityExpression() } +func (p *Parser) tryAssignmentExpression() (*Expression, error) { + pCopy := p.copy() + + lhs, err := pCopy.tryUnaryExpression() + if err != nil { + return nil, err + } + + if lhs == nil { + return nil, nil + } + + if lhs.Type != Expression_VariableReference { // TODO: allow other types + return p.tryBinaryExpression() + } + + variable := lhs.Value.(VariableReferenceExpression).Variable + op, err := pCopy.tryOperator(Operator_Equals) + if err != nil { + return nil, err + } + + if op == nil { + return p.tryBinaryExpression() + } + + expr, err := pCopy.expectExpression() + if err != nil { + return nil, err + } + + *p = pCopy + return &Expression{Type: Expression_Assignment, Value: AssignmentExpression{Variable: variable, Value: *expr}, Position: lhs.Position}, nil +} + func (p *Parser) tryExpression() (*Expression, error) { - return p.tryBinaryExpression() + return p.tryAssignmentExpression() } func (p *Parser) expectExpression() (*Expression, error) { diff --git a/validator.go b/validator.go index fae2d72..3e40c5e 100644 --- a/validator.go +++ b/validator.go @@ -131,7 +131,10 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error return errors } - // TODO: check if assignment is valid + if !isTypeExpandableTo(*assignment.Value.ValueType, local.Type) { + errors = append(errors, v.createError("cannot assign expression to variable type", expr.Position)) + } + expr.ValueType = &local.Type expr.Value = assignment case Expression_Literal: @@ -349,9 +352,16 @@ func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local) dlv := stmt.Value.(DeclareLocalVariableStatement) if dlv.Initializer != nil { errors = append(errors, v.validateExpression(dlv.Initializer)...) + if errors != nil { + return errors + } + + if !isTypeExpandableTo(*dlv.Initializer.ValueType, dlv.VariableType) { + errors = append(errors, v.createError("cannot assign expression to variable type", stmt.Position)) + } } - if _, ok := v.currentBlock.Locals[dlv.Variable]; ok { + if getLocal(v.currentBlock, dlv.Variable) != nil { errors = append(errors, v.createError("redeclaration of variable '"+dlv.Variable+"'", stmt.Position)) } @@ -359,7 +369,6 @@ func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local) v.currentBlock.Locals[dlv.Variable] = local *functionLocals = append(*functionLocals, local) - // TODO: check if assignment of initializer is correct stmt.Value = dlv case Statement_If: ifS := stmt.Value.(IfStatement) @@ -420,6 +429,8 @@ func (v *Validator) validateFunction(function *ParsedFunction) []error { errors = append(errors, v.validateBlock(body, &locals)...) + // TODO: validate that function returns return value + function.Locals = locals return errors