From 6f1490bf5a29d5176b87fbfa49f9617525165362 Mon Sep 17 00:00:00 2001 From: MrLetsplay Date: Tue, 2 Apr 2024 19:43:05 +0200 Subject: [PATCH] Raw memory access --- backend_wat.go | 130 ++++++++++++++++++++++++++++++++--------------- example/test.ely | 9 ++-- parser.go | 78 +++++++++++++++++----------- validator.go | 39 ++++++++++---- 4 files changed, 172 insertions(+), 84 deletions(-) diff --git a/backend_wat.go b/backend_wat.go index 291a5c5..8860507 100644 --- a/backend_wat.go +++ b/backend_wat.go @@ -10,6 +10,9 @@ import ( type Compiler struct { Files []*ParsedFile Wasm64 bool + + CurrentBlock *Block + CurrentFunction *ParsedFunction } func getPrimitiveWATType(primitive PrimitiveType) string { @@ -79,16 +82,20 @@ func pushConstantNumberWAT(primitive PrimitiveType, value any) string { panic("invalid type") } +func (c *Compiler) getAddressWATType() string { + if c.Wasm64 { + return "i64" + } else { + return "i32" + } +} + func (c *Compiler) getWATType(t Type) string { switch t.Type { case Type_Primitive: return getPrimitiveWATType(t.Value.(PrimitiveType)) case Type_Named, Type_Array: - if c.Wasm64 { - return "i64" - } else { - return "i32" - } + return c.getAddressWATType() case Type_Tuple: panic("tuple type passed to getWATType()") } @@ -164,23 +171,53 @@ func castPrimitiveWAT(from PrimitiveType, to PrimitiveType) (string, error) { return "i32.wrap_i64\n" + getTypeCast(to), nil } -func (c *Compiler) compileExpressionWAT(expr Expression, block *Block) (string, error) { +func (c *Compiler) compileAssignmentExpressionWAT(assignment AssignmentExpression) (string, error) { + lhs := assignment.Lhs + + exprWAT, err := c.compileExpressionWAT(assignment.Value) + if err != nil { + return "", err + } + + switch lhs.Type { + case Expression_VariableReference: + ref := lhs.Value.(VariableReferenceExpression) + local := strconv.Itoa(c.CurrentBlock.Locals[ref.Variable].Index) + return exprWAT + "local.tee $" + local + "\n", nil + case Expression_ArrayAccess: + panic("TODO") // TODO + case Expression_RawMemoryReference: + raw := lhs.Value.(RawMemoryReferenceExpression) + + local := Local{Name: "", Type: *lhs.ValueType, IsParameter: false, Index: len(c.CurrentFunction.Locals)} + c.CurrentFunction.Locals = append(c.CurrentFunction.Locals, local) + + if raw.Type.Type != Type_Primitive { + panic("TODO") //TODO + } + + addrWAT, err := c.compileExpressionWAT(raw.Address) + if err != nil { + return "", err + } + + // TODO: should leave a copy of the stored value on the stack + return addrWAT + exprWAT + + "local.tee " + strconv.Itoa(local.Index) + "\n" + + c.getWATType(raw.Type) + ".store\n" + + "local.get " + strconv.Itoa(local.Index) + "\n", nil + } + + panic("assignment expr not implemented") +} + +func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) { var err error switch expr.Type { case Expression_Assignment: ass := expr.Value.(AssignmentExpression) - - exprWAT, err := c.compileExpressionWAT(ass.Value, block) - if err != nil { - return "", err - } - - local := strconv.Itoa(block.Locals[ass.Variable].Index) - getLocal := "local.get $" + local + "\n" - setLocal := "local.set $" + local + "\n" - - return exprWAT + setLocal + getLocal, nil + return c.compileAssignmentExpressionWAT(ass) case Expression_Literal: lit := expr.Value.(LiteralExpression) switch lit.Literal.Type { @@ -204,7 +241,7 @@ func (c *Compiler) compileExpressionWAT(expr Expression, block *Block) (string, cast = getTypeCast(expr.ValueType.Value.(PrimitiveType)) } - return "local.get $" + strconv.Itoa(block.Locals[ref.Variable].Index) + "\n" + cast, nil + return "local.get $" + strconv.Itoa(c.CurrentBlock.Locals[ref.Variable].Index) + "\n" + cast, nil case Expression_Binary: binary := expr.Value.(BinaryExpression) @@ -212,12 +249,12 @@ func (c *Compiler) compileExpressionWAT(expr Expression, block *Block) (string, operandType := binary.Left.ValueType.Value.(PrimitiveType) exprType := expr.ValueType.Value.(PrimitiveType) - watLeft, err := c.compileExpressionWAT(binary.Left, block) + watLeft, err := c.compileExpressionWAT(binary.Left) if err != nil { return "", err } - watRight, err := c.compileExpressionWAT(binary.Right, block) + watRight, err := c.compileExpressionWAT(binary.Right) if err != nil { return "", err } @@ -264,7 +301,7 @@ func (c *Compiler) compileExpressionWAT(expr Expression, block *Block) (string, wat := "" for _, member := range tuple.Members { - memberWAT, err := c.compileExpressionWAT(member, block) + memberWAT, err := c.compileExpressionWAT(member) if err != nil { return "", err } @@ -278,7 +315,7 @@ func (c *Compiler) compileExpressionWAT(expr Expression, block *Block) (string, wat := "" if fc.Parameters != nil { - wat, err = c.compileExpressionWAT(*fc.Parameters, block) + wat, err = c.compileExpressionWAT(*fc.Parameters) if err != nil { return "", err } @@ -289,7 +326,7 @@ func (c *Compiler) compileExpressionWAT(expr Expression, block *Block) (string, neg := expr.Value.(NegateExpression) exprType := expr.ValueType.Value.(PrimitiveType) - wat, err := c.compileExpressionWAT(neg.Value, block) + wat, err := c.compileExpressionWAT(neg.Value) if err != nil { return "", err } @@ -305,7 +342,7 @@ func (c *Compiler) compileExpressionWAT(expr Expression, block *Block) (string, case Expression_Cast: cast := expr.Value.(CastExpression) - wat, err := c.compileExpressionWAT(cast.Value, block) + wat, err := c.compileExpressionWAT(cast.Value) if err != nil { return "", err } @@ -319,6 +356,19 @@ func (c *Compiler) compileExpressionWAT(expr Expression, block *Block) (string, } return wat + castWAT, nil + case Expression_RawMemoryReference: + raw := expr.Value.(RawMemoryReferenceExpression) + + wat, err := c.compileExpressionWAT(raw.Address) + if err != nil { + return "", err + } + + if raw.Type.Type == Type_Primitive { + wat += c.getWATType(raw.Type) + ".load\n" + } + + return wat, nil } panic("expr not implemented") @@ -328,7 +378,7 @@ func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, er switch stmt.Type { case Statement_Expression: expr := stmt.Value.(ExpressionStatement) - wat, err := c.compileExpressionWAT(expr.Expression, block) + wat, err := c.compileExpressionWAT(expr.Expression) if err != nil { return "", err } @@ -353,7 +403,7 @@ func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, er return wat, nil case Statement_Return: ret := stmt.Value.(ReturnStatement) - wat, err := c.compileExpressionWAT(*ret.Value, block) + wat, err := c.compileExpressionWAT(*ret.Value) if err != nil { return "", err } @@ -365,7 +415,7 @@ func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, er return "", nil } - wat, err := c.compileExpressionWAT(*dlv.Initializer, block) + wat, err := c.compileExpressionWAT(*dlv.Initializer) if err != nil { return "", err } @@ -374,7 +424,7 @@ func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, er case Statement_If: ifS := stmt.Value.(IfStatement) - conditionWAT, err := c.compileExpressionWAT(ifS.Condition, block) + conditionWAT, err := c.compileExpressionWAT(ifS.Condition) if err != nil { return "", err } @@ -427,6 +477,7 @@ func (c *Compiler) compileBlockWAT(block *Block) (string, error) { blockWAT := "" for _, stmt := range block.Statements { + c.CurrentBlock = block wat, err := c.compileStatementWAT(stmt, block) if err != nil { return "", err @@ -438,7 +489,13 @@ func (c *Compiler) compileBlockWAT(block *Block) (string, error) { return blockWAT, nil } -func (c *Compiler) compileFunctionWAT(function ParsedFunction) (string, error) { +func (c *Compiler) compileFunctionWAT(function *ParsedFunction) (string, error) { + c.CurrentFunction = function + blockWat, err := c.compileBlockWAT(function.Body) + if err != nil { + return "", err + } + funcWAT := "(func $" + safeASCIIIdentifier(function.FullName) + " (export \"" + function.FullName + "\")\n" for _, local := range function.Locals { @@ -467,22 +524,15 @@ func (c *Compiler) compileFunctionWAT(function ParsedFunction) (string, error) { funcWAT += "\t(local $" + strconv.Itoa(local.Index) + " " + c.getWATType(local.Type) + ")\n" } - wat, err := c.compileBlockWAT(function.Body) - if err != nil { - return "", err - } - - funcWAT += wat - - return funcWAT + ")\n", nil + return funcWAT + blockWat + ")\n", nil } func (c *Compiler) compile() (string, error) { - module := "(module\n" + module := "(module (memory 0)\n" for _, file := range c.Files { - for _, function := range file.Functions { - wat, err := c.compileFunctionWAT(function) + for i := range file.Functions { + wat, err := c.compileFunctionWAT(&file.Functions[i]) if err != nil { return "", err } diff --git a/example/test.ely b/example/test.ely index 22d69bf..687b063 100644 --- a/example/test.ely +++ b/example/test.ely @@ -1,8 +1,7 @@ -void b(u64 i) { - +void b(u64 x) { + raw(u64, 0x69u64) = x; } -(u8, u16, u64) a() { - b(1u8); - return 1u8, 2u8, 3u8; +u64 a() { + return raw(u64, 0x69u64); } diff --git a/parser.go b/parser.go index 172f0cb..119a13d 100644 --- a/parser.go +++ b/parser.go @@ -95,8 +95,8 @@ type Expression struct { } type AssignmentExpression struct { - Variable string - Value Expression + Lhs Expression + Value Expression } type LiteralExpression struct { @@ -440,7 +440,7 @@ func (p *Parser) tryParanthesizedExpression() (*Expression, error) { return expr, nil } -func (p *Parser) tryUnaryExpressionNoArrayAccess() (*Expression, error) { +func (p *Parser) tryPrimaryExpressionNoArrayAccess() (*Expression, error) { pCopy := p.copy() token := pCopy.peekToken() @@ -480,6 +480,11 @@ func (p *Parser) tryUnaryExpressionNoArrayAccess() (*Expression, error) { if keyword == Keyword_Raw { pCopy.nextToken() + _, err := pCopy.expectSeparator(Separator_OpenParen) + if err != nil { + return nil, err + } + rawType, err := pCopy.expectType() if err != nil { return nil, err @@ -495,29 +500,13 @@ func (p *Parser) tryUnaryExpressionNoArrayAccess() (*Expression, error) { return nil, err } - *p = pCopy - return &Expression{Type: Expression_RawMemoryReference, Value: RawMemoryReferenceExpression{Type: *rawType, Address: *address}, Position: token.Position}, nil - } - } - - if token.Type == Type_Operator { - op := token.Value.(Operator) - if op == Operator_Minus || op == Operator_Plus { - pCopy.nextToken() - expr, err := pCopy.tryUnaryExpression() + _, err = pCopy.expectSeparator(Separator_CloseParen) if err != nil { return nil, err } - if expr == nil { - return nil, nil - } - - if op == Operator_Minus { - expr = &Expression{Type: Expression_Negate, Value: NegateExpression{Value: *expr}, Position: token.Position} - } - - return expr, nil + *p = pCopy + return &Expression{Type: Expression_RawMemoryReference, Value: RawMemoryReferenceExpression{Type: *rawType, Address: *address}, Position: token.Position}, nil } } @@ -554,10 +543,10 @@ func (p *Parser) tryUnaryExpressionNoArrayAccess() (*Expression, error) { return nil, nil } -func (p *Parser) tryUnaryExpression() (*Expression, error) { +func (p *Parser) tryPrimaryExpression() (*Expression, error) { pCopy := p.copy() - expr, err := pCopy.tryUnaryExpressionNoArrayAccess() // TODO: wrong precedence + expr, err := pCopy.tryPrimaryExpressionNoArrayAccess() if err != nil { return nil, err } @@ -591,6 +580,38 @@ func (p *Parser) tryUnaryExpression() (*Expression, error) { } } +func (p *Parser) tryUnaryExpression() (*Expression, error) { + pCopy := p.copy() + + token := pCopy.peekToken() + if token == nil { + return nil, nil + } + + if token.Type == Type_Operator { + op := token.Value.(Operator) + if op == Operator_Minus || op == Operator_Plus { + pCopy.nextToken() + expr, err := pCopy.tryPrimaryExpression() + if err != nil { + return nil, err + } + + if expr == nil { + return nil, nil + } + + if op == Operator_Minus { + expr = &Expression{Type: Expression_Negate, Value: NegateExpression{Value: *expr}, Position: token.Position} + } + + return expr, nil + } + } + + return nil, nil +} + func (p *Parser) tryBinaryExpression0(opFunc func() (*Expression, error), operators ...Operator) (*Expression, error) { left, err := opFunc() if err != nil { @@ -625,7 +646,7 @@ func (p *Parser) tryBinaryExpression0(opFunc func() (*Expression, error), operat } func (p *Parser) tryMultiplicativeExpression() (*Expression, error) { - return p.tryBinaryExpression0(p.tryUnaryExpression, Operator_Multiply, Operator_Divide, Operator_Modulo) + return p.tryBinaryExpression0(p.tryPrimaryExpression, Operator_Multiply, Operator_Divide, Operator_Modulo) } func (p *Parser) tryAdditiveExpression() (*Expression, error) { @@ -647,7 +668,7 @@ func (p *Parser) tryBinaryExpression() (*Expression, error) { func (p *Parser) tryAssignmentExpression() (*Expression, error) { pCopy := p.copy() - lhs, err := pCopy.tryUnaryExpression() + lhs, err := pCopy.tryPrimaryExpression() if err != nil { return nil, err } @@ -656,11 +677,10 @@ func (p *Parser) tryAssignmentExpression() (*Expression, error) { return nil, nil } - if lhs.Type != Expression_VariableReference { // TODO: allow other types (array access) + if lhs.Type != Expression_VariableReference && lhs.Type != Expression_ArrayAccess && lhs.Type != Expression_RawMemoryReference { return p.tryBinaryExpression() } - variable := lhs.Value.(VariableReferenceExpression).Variable op, err := pCopy.tryOperator(Operator_Equals) if err != nil { return nil, err @@ -676,7 +696,7 @@ func (p *Parser) tryAssignmentExpression() (*Expression, error) { } *p = pCopy - return &Expression{Type: Expression_Assignment, Value: AssignmentExpression{Variable: variable, Value: *expr}, Position: lhs.Position}, nil + return &Expression{Type: Expression_Assignment, Value: AssignmentExpression{Lhs: *lhs, Value: *expr}, Position: lhs.Position}, nil } func (p *Parser) tryExpression() (*Expression, error) { diff --git a/validator.go b/validator.go index 64fcce3..4b560de 100644 --- a/validator.go +++ b/validator.go @@ -7,6 +7,7 @@ import ( type Validator struct { Files []*ParsedFile + Wasm64 bool AllFunctions map[string]*ParsedFunction CurrentBlock *Block @@ -176,11 +177,8 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error switch expr.Type { case Expression_Assignment: assignment := expr.Value.(AssignmentExpression) - local := getLocal(v.CurrentBlock, assignment.Variable) - if local == nil { - errors = append(errors, v.createError("Assignment to undeclared variable "+assignment.Variable, expr.Position)) - return errors - } + + errors = append(errors, v.validateExpression(&assignment.Lhs)...) valueErrors := v.validateExpression(&assignment.Value) if len(valueErrors) != 0 { @@ -188,15 +186,15 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error return errors } - if !isSameType(*assignment.Value.ValueType, local.Type) { - if !isTypeExpandableTo(*assignment.Value.ValueType, local.Type) { - errors = append(errors, v.createError(fmt.Sprintf("cannot assign expression of type %s to variable of type %s", *assignment.Value.ValueType, local.Type), expr.Position)) + if !isSameType(*assignment.Value.ValueType, *assignment.Lhs.ValueType) { + if !isTypeExpandableTo(*assignment.Value.ValueType, *assignment.Lhs.ValueType) { + errors = append(errors, v.createError(fmt.Sprintf("cannot assign expression of type %s to variable of type %s", *assignment.Value.ValueType, *assignment.Lhs.ValueType), expr.Position)) } - expandExpressionToType(&assignment.Value, local.Type) + expandExpressionToType(&assignment.Value, *assignment.Lhs.ValueType) } - expr.ValueType = &local.Type + expr.ValueType = assignment.Lhs.ValueType expr.Value = assignment case Expression_Literal: literal := expr.Value.(LiteralExpression) @@ -344,6 +342,27 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error expr.ValueType = neg.Value.ValueType expr.Value = neg + case Expression_RawMemoryReference: + raw := expr.Value.(RawMemoryReferenceExpression) + + addrErrors := v.validateExpression(&raw.Address) + if len(addrErrors) != 0 { + errors = append(errors, addrErrors...) + return errors + } + + if raw.Address.ValueType.Type != Type_Primitive || raw.Address.ValueType.Value.(PrimitiveType) != Primitive_U64 { + errors = append(errors, v.createError("address must evaluate to a u64 value", expr.Position)) + return errors + } + + if !v.Wasm64 { + castTo := Type{Type: Type_Primitive, Value: Primitive_U32} + raw.Address = Expression{Type: Expression_Cast, Value: CastExpression{Type: castTo, Value: raw.Address}, ValueType: &castTo, Position: raw.Address.Position} + } + + expr.ValueType = &raw.Type + expr.Value = raw default: panic("expr not implemented") }