Raw memory access

This commit is contained in:
MrLetsplay 2024-04-02 19:43:05 +02:00
parent e7cdf0c929
commit 6f1490bf5a
Signed by: mr
SSH Key Fingerprint: SHA256:92jBH80vpXyaZHjaIl47pjRq+Yt7XGTArqQg1V7hSqg
4 changed files with 172 additions and 84 deletions

View File

@ -10,6 +10,9 @@ import (
type Compiler struct { type Compiler struct {
Files []*ParsedFile Files []*ParsedFile
Wasm64 bool Wasm64 bool
CurrentBlock *Block
CurrentFunction *ParsedFunction
} }
func getPrimitiveWATType(primitive PrimitiveType) string { func getPrimitiveWATType(primitive PrimitiveType) string {
@ -79,16 +82,20 @@ func pushConstantNumberWAT(primitive PrimitiveType, value any) string {
panic("invalid type") panic("invalid type")
} }
func (c *Compiler) getAddressWATType() string {
if c.Wasm64 {
return "i64"
} else {
return "i32"
}
}
func (c *Compiler) getWATType(t Type) string { func (c *Compiler) getWATType(t Type) string {
switch t.Type { switch t.Type {
case Type_Primitive: case Type_Primitive:
return getPrimitiveWATType(t.Value.(PrimitiveType)) return getPrimitiveWATType(t.Value.(PrimitiveType))
case Type_Named, Type_Array: case Type_Named, Type_Array:
if c.Wasm64 { return c.getAddressWATType()
return "i64"
} else {
return "i32"
}
case Type_Tuple: case Type_Tuple:
panic("tuple type passed to getWATType()") panic("tuple type passed to getWATType()")
} }
@ -164,23 +171,53 @@ func castPrimitiveWAT(from PrimitiveType, to PrimitiveType) (string, error) {
return "i32.wrap_i64\n" + getTypeCast(to), nil return "i32.wrap_i64\n" + getTypeCast(to), nil
} }
func (c *Compiler) compileExpressionWAT(expr Expression, block *Block) (string, error) { func (c *Compiler) compileAssignmentExpressionWAT(assignment AssignmentExpression) (string, error) {
lhs := assignment.Lhs
exprWAT, err := c.compileExpressionWAT(assignment.Value)
if err != nil {
return "", err
}
switch lhs.Type {
case Expression_VariableReference:
ref := lhs.Value.(VariableReferenceExpression)
local := strconv.Itoa(c.CurrentBlock.Locals[ref.Variable].Index)
return exprWAT + "local.tee $" + local + "\n", nil
case Expression_ArrayAccess:
panic("TODO") // TODO
case Expression_RawMemoryReference:
raw := lhs.Value.(RawMemoryReferenceExpression)
local := Local{Name: "", Type: *lhs.ValueType, IsParameter: false, Index: len(c.CurrentFunction.Locals)}
c.CurrentFunction.Locals = append(c.CurrentFunction.Locals, local)
if raw.Type.Type != Type_Primitive {
panic("TODO") //TODO
}
addrWAT, err := c.compileExpressionWAT(raw.Address)
if err != nil {
return "", err
}
// TODO: should leave a copy of the stored value on the stack
return addrWAT + exprWAT +
"local.tee " + strconv.Itoa(local.Index) + "\n" +
c.getWATType(raw.Type) + ".store\n" +
"local.get " + strconv.Itoa(local.Index) + "\n", nil
}
panic("assignment expr not implemented")
}
func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) {
var err error var err error
switch expr.Type { switch expr.Type {
case Expression_Assignment: case Expression_Assignment:
ass := expr.Value.(AssignmentExpression) ass := expr.Value.(AssignmentExpression)
return c.compileAssignmentExpressionWAT(ass)
exprWAT, err := c.compileExpressionWAT(ass.Value, block)
if err != nil {
return "", err
}
local := strconv.Itoa(block.Locals[ass.Variable].Index)
getLocal := "local.get $" + local + "\n"
setLocal := "local.set $" + local + "\n"
return exprWAT + setLocal + getLocal, nil
case Expression_Literal: case Expression_Literal:
lit := expr.Value.(LiteralExpression) lit := expr.Value.(LiteralExpression)
switch lit.Literal.Type { switch lit.Literal.Type {
@ -204,7 +241,7 @@ func (c *Compiler) compileExpressionWAT(expr Expression, block *Block) (string,
cast = getTypeCast(expr.ValueType.Value.(PrimitiveType)) cast = getTypeCast(expr.ValueType.Value.(PrimitiveType))
} }
return "local.get $" + strconv.Itoa(block.Locals[ref.Variable].Index) + "\n" + cast, nil return "local.get $" + strconv.Itoa(c.CurrentBlock.Locals[ref.Variable].Index) + "\n" + cast, nil
case Expression_Binary: case Expression_Binary:
binary := expr.Value.(BinaryExpression) binary := expr.Value.(BinaryExpression)
@ -212,12 +249,12 @@ func (c *Compiler) compileExpressionWAT(expr Expression, block *Block) (string,
operandType := binary.Left.ValueType.Value.(PrimitiveType) operandType := binary.Left.ValueType.Value.(PrimitiveType)
exprType := expr.ValueType.Value.(PrimitiveType) exprType := expr.ValueType.Value.(PrimitiveType)
watLeft, err := c.compileExpressionWAT(binary.Left, block) watLeft, err := c.compileExpressionWAT(binary.Left)
if err != nil { if err != nil {
return "", err return "", err
} }
watRight, err := c.compileExpressionWAT(binary.Right, block) watRight, err := c.compileExpressionWAT(binary.Right)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -264,7 +301,7 @@ func (c *Compiler) compileExpressionWAT(expr Expression, block *Block) (string,
wat := "" wat := ""
for _, member := range tuple.Members { for _, member := range tuple.Members {
memberWAT, err := c.compileExpressionWAT(member, block) memberWAT, err := c.compileExpressionWAT(member)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -278,7 +315,7 @@ func (c *Compiler) compileExpressionWAT(expr Expression, block *Block) (string,
wat := "" wat := ""
if fc.Parameters != nil { if fc.Parameters != nil {
wat, err = c.compileExpressionWAT(*fc.Parameters, block) wat, err = c.compileExpressionWAT(*fc.Parameters)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -289,7 +326,7 @@ func (c *Compiler) compileExpressionWAT(expr Expression, block *Block) (string,
neg := expr.Value.(NegateExpression) neg := expr.Value.(NegateExpression)
exprType := expr.ValueType.Value.(PrimitiveType) exprType := expr.ValueType.Value.(PrimitiveType)
wat, err := c.compileExpressionWAT(neg.Value, block) wat, err := c.compileExpressionWAT(neg.Value)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -305,7 +342,7 @@ func (c *Compiler) compileExpressionWAT(expr Expression, block *Block) (string,
case Expression_Cast: case Expression_Cast:
cast := expr.Value.(CastExpression) cast := expr.Value.(CastExpression)
wat, err := c.compileExpressionWAT(cast.Value, block) wat, err := c.compileExpressionWAT(cast.Value)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -319,6 +356,19 @@ func (c *Compiler) compileExpressionWAT(expr Expression, block *Block) (string,
} }
return wat + castWAT, nil return wat + castWAT, nil
case Expression_RawMemoryReference:
raw := expr.Value.(RawMemoryReferenceExpression)
wat, err := c.compileExpressionWAT(raw.Address)
if err != nil {
return "", err
}
if raw.Type.Type == Type_Primitive {
wat += c.getWATType(raw.Type) + ".load\n"
}
return wat, nil
} }
panic("expr not implemented") panic("expr not implemented")
@ -328,7 +378,7 @@ func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, er
switch stmt.Type { switch stmt.Type {
case Statement_Expression: case Statement_Expression:
expr := stmt.Value.(ExpressionStatement) expr := stmt.Value.(ExpressionStatement)
wat, err := c.compileExpressionWAT(expr.Expression, block) wat, err := c.compileExpressionWAT(expr.Expression)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -353,7 +403,7 @@ func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, er
return wat, nil return wat, nil
case Statement_Return: case Statement_Return:
ret := stmt.Value.(ReturnStatement) ret := stmt.Value.(ReturnStatement)
wat, err := c.compileExpressionWAT(*ret.Value, block) wat, err := c.compileExpressionWAT(*ret.Value)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -365,7 +415,7 @@ func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, er
return "", nil return "", nil
} }
wat, err := c.compileExpressionWAT(*dlv.Initializer, block) wat, err := c.compileExpressionWAT(*dlv.Initializer)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -374,7 +424,7 @@ func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, er
case Statement_If: case Statement_If:
ifS := stmt.Value.(IfStatement) ifS := stmt.Value.(IfStatement)
conditionWAT, err := c.compileExpressionWAT(ifS.Condition, block) conditionWAT, err := c.compileExpressionWAT(ifS.Condition)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -427,6 +477,7 @@ func (c *Compiler) compileBlockWAT(block *Block) (string, error) {
blockWAT := "" blockWAT := ""
for _, stmt := range block.Statements { for _, stmt := range block.Statements {
c.CurrentBlock = block
wat, err := c.compileStatementWAT(stmt, block) wat, err := c.compileStatementWAT(stmt, block)
if err != nil { if err != nil {
return "", err return "", err
@ -438,7 +489,13 @@ func (c *Compiler) compileBlockWAT(block *Block) (string, error) {
return blockWAT, nil return blockWAT, nil
} }
func (c *Compiler) compileFunctionWAT(function ParsedFunction) (string, error) { func (c *Compiler) compileFunctionWAT(function *ParsedFunction) (string, error) {
c.CurrentFunction = function
blockWat, err := c.compileBlockWAT(function.Body)
if err != nil {
return "", err
}
funcWAT := "(func $" + safeASCIIIdentifier(function.FullName) + " (export \"" + function.FullName + "\")\n" funcWAT := "(func $" + safeASCIIIdentifier(function.FullName) + " (export \"" + function.FullName + "\")\n"
for _, local := range function.Locals { for _, local := range function.Locals {
@ -467,22 +524,15 @@ func (c *Compiler) compileFunctionWAT(function ParsedFunction) (string, error) {
funcWAT += "\t(local $" + strconv.Itoa(local.Index) + " " + c.getWATType(local.Type) + ")\n" funcWAT += "\t(local $" + strconv.Itoa(local.Index) + " " + c.getWATType(local.Type) + ")\n"
} }
wat, err := c.compileBlockWAT(function.Body) return funcWAT + blockWat + ")\n", nil
if err != nil {
return "", err
}
funcWAT += wat
return funcWAT + ")\n", nil
} }
func (c *Compiler) compile() (string, error) { func (c *Compiler) compile() (string, error) {
module := "(module\n" module := "(module (memory 0)\n"
for _, file := range c.Files { for _, file := range c.Files {
for _, function := range file.Functions { for i := range file.Functions {
wat, err := c.compileFunctionWAT(function) wat, err := c.compileFunctionWAT(&file.Functions[i])
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@ -1,8 +1,7 @@
void b(u64 i) { void b(u64 x) {
raw(u64, 0x69u64) = x;
} }
(u8, u16, u64) a() { u64 a() {
b(1u8); return raw(u64, 0x69u64);
return 1u8, 2u8, 3u8;
} }

View File

@ -95,8 +95,8 @@ type Expression struct {
} }
type AssignmentExpression struct { type AssignmentExpression struct {
Variable string Lhs Expression
Value Expression Value Expression
} }
type LiteralExpression struct { type LiteralExpression struct {
@ -440,7 +440,7 @@ func (p *Parser) tryParanthesizedExpression() (*Expression, error) {
return expr, nil return expr, nil
} }
func (p *Parser) tryUnaryExpressionNoArrayAccess() (*Expression, error) { func (p *Parser) tryPrimaryExpressionNoArrayAccess() (*Expression, error) {
pCopy := p.copy() pCopy := p.copy()
token := pCopy.peekToken() token := pCopy.peekToken()
@ -480,6 +480,11 @@ func (p *Parser) tryUnaryExpressionNoArrayAccess() (*Expression, error) {
if keyword == Keyword_Raw { if keyword == Keyword_Raw {
pCopy.nextToken() pCopy.nextToken()
_, err := pCopy.expectSeparator(Separator_OpenParen)
if err != nil {
return nil, err
}
rawType, err := pCopy.expectType() rawType, err := pCopy.expectType()
if err != nil { if err != nil {
return nil, err return nil, err
@ -495,29 +500,13 @@ func (p *Parser) tryUnaryExpressionNoArrayAccess() (*Expression, error) {
return nil, err return nil, err
} }
*p = pCopy _, err = pCopy.expectSeparator(Separator_CloseParen)
return &Expression{Type: Expression_RawMemoryReference, Value: RawMemoryReferenceExpression{Type: *rawType, Address: *address}, Position: token.Position}, nil
}
}
if token.Type == Type_Operator {
op := token.Value.(Operator)
if op == Operator_Minus || op == Operator_Plus {
pCopy.nextToken()
expr, err := pCopy.tryUnaryExpression()
if err != nil { if err != nil {
return nil, err return nil, err
} }
if expr == nil { *p = pCopy
return nil, nil return &Expression{Type: Expression_RawMemoryReference, Value: RawMemoryReferenceExpression{Type: *rawType, Address: *address}, Position: token.Position}, nil
}
if op == Operator_Minus {
expr = &Expression{Type: Expression_Negate, Value: NegateExpression{Value: *expr}, Position: token.Position}
}
return expr, nil
} }
} }
@ -554,10 +543,10 @@ func (p *Parser) tryUnaryExpressionNoArrayAccess() (*Expression, error) {
return nil, nil return nil, nil
} }
func (p *Parser) tryUnaryExpression() (*Expression, error) { func (p *Parser) tryPrimaryExpression() (*Expression, error) {
pCopy := p.copy() pCopy := p.copy()
expr, err := pCopy.tryUnaryExpressionNoArrayAccess() // TODO: wrong precedence expr, err := pCopy.tryPrimaryExpressionNoArrayAccess()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -591,6 +580,38 @@ func (p *Parser) tryUnaryExpression() (*Expression, error) {
} }
} }
func (p *Parser) tryUnaryExpression() (*Expression, error) {
pCopy := p.copy()
token := pCopy.peekToken()
if token == nil {
return nil, nil
}
if token.Type == Type_Operator {
op := token.Value.(Operator)
if op == Operator_Minus || op == Operator_Plus {
pCopy.nextToken()
expr, err := pCopy.tryPrimaryExpression()
if err != nil {
return nil, err
}
if expr == nil {
return nil, nil
}
if op == Operator_Minus {
expr = &Expression{Type: Expression_Negate, Value: NegateExpression{Value: *expr}, Position: token.Position}
}
return expr, nil
}
}
return nil, nil
}
func (p *Parser) tryBinaryExpression0(opFunc func() (*Expression, error), operators ...Operator) (*Expression, error) { func (p *Parser) tryBinaryExpression0(opFunc func() (*Expression, error), operators ...Operator) (*Expression, error) {
left, err := opFunc() left, err := opFunc()
if err != nil { if err != nil {
@ -625,7 +646,7 @@ func (p *Parser) tryBinaryExpression0(opFunc func() (*Expression, error), operat
} }
func (p *Parser) tryMultiplicativeExpression() (*Expression, error) { func (p *Parser) tryMultiplicativeExpression() (*Expression, error) {
return p.tryBinaryExpression0(p.tryUnaryExpression, Operator_Multiply, Operator_Divide, Operator_Modulo) return p.tryBinaryExpression0(p.tryPrimaryExpression, Operator_Multiply, Operator_Divide, Operator_Modulo)
} }
func (p *Parser) tryAdditiveExpression() (*Expression, error) { func (p *Parser) tryAdditiveExpression() (*Expression, error) {
@ -647,7 +668,7 @@ func (p *Parser) tryBinaryExpression() (*Expression, error) {
func (p *Parser) tryAssignmentExpression() (*Expression, error) { func (p *Parser) tryAssignmentExpression() (*Expression, error) {
pCopy := p.copy() pCopy := p.copy()
lhs, err := pCopy.tryUnaryExpression() lhs, err := pCopy.tryPrimaryExpression()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -656,11 +677,10 @@ func (p *Parser) tryAssignmentExpression() (*Expression, error) {
return nil, nil return nil, nil
} }
if lhs.Type != Expression_VariableReference { // TODO: allow other types (array access) if lhs.Type != Expression_VariableReference && lhs.Type != Expression_ArrayAccess && lhs.Type != Expression_RawMemoryReference {
return p.tryBinaryExpression() return p.tryBinaryExpression()
} }
variable := lhs.Value.(VariableReferenceExpression).Variable
op, err := pCopy.tryOperator(Operator_Equals) op, err := pCopy.tryOperator(Operator_Equals)
if err != nil { if err != nil {
return nil, err return nil, err
@ -676,7 +696,7 @@ func (p *Parser) tryAssignmentExpression() (*Expression, error) {
} }
*p = pCopy *p = pCopy
return &Expression{Type: Expression_Assignment, Value: AssignmentExpression{Variable: variable, Value: *expr}, Position: lhs.Position}, nil return &Expression{Type: Expression_Assignment, Value: AssignmentExpression{Lhs: *lhs, Value: *expr}, Position: lhs.Position}, nil
} }
func (p *Parser) tryExpression() (*Expression, error) { func (p *Parser) tryExpression() (*Expression, error) {

View File

@ -7,6 +7,7 @@ import (
type Validator struct { type Validator struct {
Files []*ParsedFile Files []*ParsedFile
Wasm64 bool
AllFunctions map[string]*ParsedFunction AllFunctions map[string]*ParsedFunction
CurrentBlock *Block CurrentBlock *Block
@ -176,11 +177,8 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error
switch expr.Type { switch expr.Type {
case Expression_Assignment: case Expression_Assignment:
assignment := expr.Value.(AssignmentExpression) assignment := expr.Value.(AssignmentExpression)
local := getLocal(v.CurrentBlock, assignment.Variable)
if local == nil { errors = append(errors, v.validateExpression(&assignment.Lhs)...)
errors = append(errors, v.createError("Assignment to undeclared variable "+assignment.Variable, expr.Position))
return errors
}
valueErrors := v.validateExpression(&assignment.Value) valueErrors := v.validateExpression(&assignment.Value)
if len(valueErrors) != 0 { if len(valueErrors) != 0 {
@ -188,15 +186,15 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error
return errors return errors
} }
if !isSameType(*assignment.Value.ValueType, local.Type) { if !isSameType(*assignment.Value.ValueType, *assignment.Lhs.ValueType) {
if !isTypeExpandableTo(*assignment.Value.ValueType, local.Type) { if !isTypeExpandableTo(*assignment.Value.ValueType, *assignment.Lhs.ValueType) {
errors = append(errors, v.createError(fmt.Sprintf("cannot assign expression of type %s to variable of type %s", *assignment.Value.ValueType, local.Type), expr.Position)) errors = append(errors, v.createError(fmt.Sprintf("cannot assign expression of type %s to variable of type %s", *assignment.Value.ValueType, *assignment.Lhs.ValueType), expr.Position))
} }
expandExpressionToType(&assignment.Value, local.Type) expandExpressionToType(&assignment.Value, *assignment.Lhs.ValueType)
} }
expr.ValueType = &local.Type expr.ValueType = assignment.Lhs.ValueType
expr.Value = assignment expr.Value = assignment
case Expression_Literal: case Expression_Literal:
literal := expr.Value.(LiteralExpression) literal := expr.Value.(LiteralExpression)
@ -344,6 +342,27 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error
expr.ValueType = neg.Value.ValueType expr.ValueType = neg.Value.ValueType
expr.Value = neg expr.Value = neg
case Expression_RawMemoryReference:
raw := expr.Value.(RawMemoryReferenceExpression)
addrErrors := v.validateExpression(&raw.Address)
if len(addrErrors) != 0 {
errors = append(errors, addrErrors...)
return errors
}
if raw.Address.ValueType.Type != Type_Primitive || raw.Address.ValueType.Value.(PrimitiveType) != Primitive_U64 {
errors = append(errors, v.createError("address must evaluate to a u64 value", expr.Position))
return errors
}
if !v.Wasm64 {
castTo := Type{Type: Type_Primitive, Value: Primitive_U32}
raw.Address = Expression{Type: Expression_Cast, Value: CastExpression{Type: castTo, Value: raw.Address}, ValueType: &castTo, Position: raw.Address.Position}
}
expr.ValueType = &raw.Type
expr.Value = raw
default: default:
panic("expr not implemented") panic("expr not implemented")
} }