Array assignment

This commit is contained in:
MrLetsplay 2024-10-31 20:50:09 +01:00
parent c119a077d6
commit 3af1535515
Signed by: mr
SSH Key Fingerprint: SHA256:92jBH80vpXyaZHjaIl47pjRq+Yt7XGTArqQg1V7hSqg
7 changed files with 144 additions and 30 deletions

View File

@ -17,3 +17,4 @@ The Elysium programming language.
- [ ] Memory allocation, Heap allocator - [ ] Memory allocation, Heap allocator
- [ ] Garbage collector - [ ] Garbage collector
- [ ] Support for wasm64 - [ ] Support for wasm64
- [ ] Imported functions (from JS)

View File

@ -8,9 +8,12 @@ import (
"unicode" "unicode"
) )
const COMPILE_OPTION_NO_BOUNDS_CHECK = "no_bounds_check"
type Compiler struct { type Compiler struct {
Files []*ParsedFile Files []*ParsedFile
Wasm64 bool Wasm64 bool
CompileOptions map[string]string
CurrentBlock *Block CurrentBlock *Block
CurrentFunction *ParsedFunction CurrentFunction *ParsedFunction
@ -58,7 +61,7 @@ func getTypeCast(primitive PrimitiveType) string {
case Primitive_U16: case Primitive_U16:
return "i32.const 65535\ni32.and\n" return "i32.const 65535\ni32.and\n"
case Primitive_Bool: case Primitive_Bool:
return "i32.const 1\ni32.and\n" return "i32.const 0\ni32.ne\n"
} }
return "" return ""
@ -222,7 +225,78 @@ func (c *Compiler) compileAssignmentExpressionWAT(assignment AssignmentExpressio
local := strconv.Itoa(getLocal(c.CurrentBlock, ref.Variable).Index) local := strconv.Itoa(getLocal(c.CurrentBlock, ref.Variable).Index)
return exprWAT + "local.tee $" + local + "\n", nil return exprWAT + "local.tee $" + local + "\n", nil
case Expression_ArrayAccess: case Expression_ArrayAccess:
panic("TODO") // TODO array := lhs.Value.(ArrayAccessExpression)
localArray := Local{Name: "", Type: Type{Type: Type_Primitive, Value: c.getEffectiveAddressType(), Position: unknownPosition()}, IsParameter: false, Index: len(c.CurrentFunction.Locals)}
c.CurrentFunction.Locals = append(c.CurrentFunction.Locals, localArray)
localIndex := Local{Name: "", Type: Type{Type: Type_Primitive, Value: c.getEffectiveAddressType(), Position: unknownPosition()}, IsParameter: false, Index: len(c.CurrentFunction.Locals)}
c.CurrentFunction.Locals = append(c.CurrentFunction.Locals, localIndex)
localElement := Local{Name: "", Type: *assignment.Value.ValueType, IsParameter: false, Index: len(c.CurrentFunction.Locals)}
c.CurrentFunction.Locals = append(c.CurrentFunction.Locals, localElement)
arrayWAT, err := c.compileExpressionWAT(array.Array)
if err != nil {
return "", err
}
arrayWAT += "local.set $" + strconv.Itoa(localArray.Index) + "\n"
indexWAT, err := c.compileExpressionWAT(array.Index)
if err != nil {
return "", err
}
if !c.Wasm64 {
cast, err := castPrimitiveWAT(Primitive_I64, Primitive_I32)
if err != nil {
return "", err
}
indexWAT += cast
}
indexWAT += "local.set $" + strconv.Itoa(localIndex.Index) + "\n"
wat := arrayWAT + indexWAT
if _, ok := c.CompileOptions[COMPILE_OPTION_NO_BOUNDS_CHECK]; !ok {
// Error if index < 0
wat += "block\n"
wat += "local.get $" + strconv.Itoa(localIndex.Index) + "\n"
wat += c.getAddressWATType() + ".const 0\n"
wat += c.getAddressWATType() + ".ge_s\n"
wat += "br_if 0\n"
wat += "call $__builtin_panic\n"
wat += "end\n"
// Error if index >= array length
wat += "block\n"
wat += "local.get $" + strconv.Itoa(localIndex.Index) + "\n"
wat += "local.get $" + strconv.Itoa(localArray.Index) + "\n"
wat += "i32.load\n" // Load array length
wat += c.getAddressWATType() + ".lt_s\n"
wat += "br_if 0\n"
wat += "call $__builtin_panic\n"
wat += "end\n"
}
elementType := array.Array.ValueType.Value.(ArrayType).ElementType
wat += "local.get $" + strconv.Itoa(localIndex.Index) + "\n"
wat += c.getAddressWATType() + ".const " + strconv.Itoa(c.getTypeSizeBytes(elementType)) + "\n"
wat += c.getAddressWATType() + ".mul\n"
wat += "local.get $" + strconv.Itoa(localArray.Index) + "\n"
wat += c.getAddressWATType() + ".add\n"
wat += c.getAddressWATType() + ".const 4\n" // first 4 bytes = length
wat += c.getAddressWATType() + ".add\n"
wat += exprWAT
wat += "local.tee $" + strconv.Itoa(localElement.Index) + "\n"
wat += c.getWATType(elementType) + ".store\n" // TODO: use load8/load16(_s/u) for smaller types
wat += "local.get $" + strconv.Itoa(localElement.Index) + "\n"
return wat, nil
case Expression_RawMemoryReference: case Expression_RawMemoryReference:
raw := lhs.Value.(RawMemoryReferenceExpression) raw := lhs.Value.(RawMemoryReferenceExpression)
@ -240,9 +314,9 @@ func (c *Compiler) compileAssignmentExpressionWAT(assignment AssignmentExpressio
// TODO: should leave a copy of the stored value on the stack // TODO: should leave a copy of the stored value on the stack
return addrWAT + exprWAT + return addrWAT + exprWAT +
"local.tee " + strconv.Itoa(local.Index) + "\n" + "local.tee $" + strconv.Itoa(local.Index) + "\n" +
c.getWATType(raw.Type) + ".store\n" + c.getWATType(raw.Type) + ".store\n" +
"local.get " + strconv.Itoa(local.Index) + "\n", nil "local.get $" + strconv.Itoa(local.Index) + "\n", nil
} }
panic("assignment expr not implemented") panic("assignment expr not implemented")
@ -555,11 +629,12 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) {
wat := arrayWAT + indexWAT wat := arrayWAT + indexWAT
// Error if index <= 0 if _, ok := c.CompileOptions[COMPILE_OPTION_NO_BOUNDS_CHECK]; !ok {
// Error if index < 0
wat += "block\n" wat += "block\n"
wat += "local.get $" + strconv.Itoa(localIndex.Index) + "\n" wat += "local.get $" + strconv.Itoa(localIndex.Index) + "\n"
wat += c.getAddressWATType() + ".const 0\n" wat += c.getAddressWATType() + ".const 0\n"
wat += c.getAddressWATType() + ".gt_s\n" wat += c.getAddressWATType() + ".ge_s\n"
wat += "br_if 0\n" wat += "br_if 0\n"
wat += "call $__builtin_panic\n" wat += "call $__builtin_panic\n"
wat += "end\n" wat += "end\n"
@ -573,6 +648,7 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) {
wat += "br_if 0\n" wat += "br_if 0\n"
wat += "call $__builtin_panic\n" wat += "call $__builtin_panic\n"
wat += "end\n" wat += "end\n"
}
elementType := array.Array.ValueType.Value.(ArrayType).ElementType elementType := array.Array.ValueType.Value.(ArrayType).ElementType
wat += "local.get $" + strconv.Itoa(localIndex.Index) + "\n" wat += "local.get $" + strconv.Itoa(localIndex.Index) + "\n"

27
example/array.ely Normal file
View File

@ -0,0 +1,27 @@
u64 array() {
if(__builtin_memory_size() == 0u64) {
__builtin_memory_grow(1u64);
}
raw(u32, 0x0u64) = 1u32;
raw(u64, 0x04u64) = 69u64;
u64[] array = raw(u64[], 0x0u64);
array[0] = 1u64;
array[1] = 2u64;
return array[0] + array[1];
}
void array2() {
u64[] array = raw(u64[], 0x0u64);
array[0] = 1u64;
}
u64[] arrayReturn() {
if(__builtin_memory_size() == 0u64) {
__builtin_memory_grow(1u64);
}
return raw(u64[], 0x0u64);
}

View File

@ -38,7 +38,3 @@ u64 assign(u64 a) {
a += 1u64; a += 1u64;
return raw(u64, a) += 2u64; return raw(u64, a) += 2u64;
} }
u64 test() {
return raw(u64[], 0x0u8)[0];
}

13
main.go
View File

@ -90,6 +90,7 @@ func main() {
generateWAT := flag.Bool("wat", false, "Generate WAT instead of WASM") generateWAT := flag.Bool("wat", false, "Generate WAT instead of WASM")
wasm64 := flag.Bool("wasm64", false, "Use 64-bit memory (may not be supported in all browsers)") wasm64 := flag.Bool("wasm64", false, "Use 64-bit memory (may not be supported in all browsers)")
includeStdlib := flag.Bool("stdlib", true, "Include the standard library") includeStdlib := flag.Bool("stdlib", true, "Include the standard library")
compileOptions := flag.String("compileOptions", "", "The compile options (key=value,key2=value2,key3)")
flag.Parse() flag.Parse()
if len(os.Args) < 2 { if len(os.Args) < 2 {
@ -169,7 +170,17 @@ func main() {
// log.Printf("Validated:\n%+#v\n\n", parsedFiles) // log.Printf("Validated:\n%+#v\n\n", parsedFiles)
compiler := Compiler{Files: parsedFiles, Wasm64: *wasm64} compileOptionsMap := make(map[string]string, 0)
for _, value := range strings.Split(*compileOptions, ",") {
kv := strings.Split(value, "=")
if len(kv) == 1 {
compileOptionsMap[kv[0]] = ""
} else {
compileOptionsMap[kv[0]] = kv[1]
}
}
compiler := Compiler{Files: parsedFiles, Wasm64: *wasm64, CompileOptions: compileOptionsMap}
wat, err := compiler.compile() wat, err := compiler.compile()
if err != nil { if err != nil {
if c, ok := err.(CompilerError); ok { if c, ok := err.(CompilerError); ok {

View File

@ -206,6 +206,7 @@ type ParsedFunction struct {
ReturnType *Type ReturnType *Type
Body *Block Body *Block
Locals []Local // All of the locals of the function, ordered by their index Locals []Local // All of the locals of the function, ordered by their index
Position TokenPosition
} }
type Import struct { type Import struct {
@ -1101,7 +1102,7 @@ func (p *Parser) expectFunction() (*ParsedFunction, error) {
return nil, err return nil, err
} }
return &ParsedFunction{Name: name, Parameters: parameters, ReturnType: returnType, Body: body}, nil return &ParsedFunction{Name: name, Parameters: parameters, ReturnType: returnType, Body: body, Position: tok.Position}, nil
} }
func (p *Parser) parseFile() (*ParsedFile, error) { func (p *Parser) parseFile() (*ParsedFile, error) {

View File

@ -25,12 +25,14 @@ var builtinFunctions map[string]*ParsedFunction = map[string]*ParsedFunction{
FullName: "builtin." + BUILTIN_MEMORY_GROW, FullName: "builtin." + BUILTIN_MEMORY_GROW,
Parameters: []ParsedParameter{{Name: "memory", Type: Type{Type: Type_Primitive, Value: Primitive_U64, Position: unknownPosition()}}}, Parameters: []ParsedParameter{{Name: "memory", Type: Type{Type: Type_Primitive, Value: Primitive_U64, Position: unknownPosition()}}},
ReturnType: &Type{Type: Type_Primitive, Value: Primitive_I64, Position: unknownPosition()}, ReturnType: &Type{Type: Type_Primitive, Value: Primitive_I64, Position: unknownPosition()},
Position: unknownPosition(),
}, },
BUILTIN_MEMORY_SIZE: { BUILTIN_MEMORY_SIZE: {
Name: BUILTIN_MEMORY_SIZE, Name: BUILTIN_MEMORY_SIZE,
FullName: "builtin." + BUILTIN_MEMORY_SIZE, FullName: "builtin." + BUILTIN_MEMORY_SIZE,
Parameters: []ParsedParameter{}, Parameters: []ParsedParameter{},
ReturnType: &Type{Type: Type_Primitive, Value: Primitive_U64, Position: unknownPosition()}, ReturnType: &Type{Type: Type_Primitive, Value: Primitive_U64, Position: unknownPosition()},
Position: unknownPosition(),
}, },
} }
@ -640,7 +642,7 @@ func (v *Validator) validate() []error {
function.FullName = fullFunctionName function.FullName = fullFunctionName
if _, exists := v.AllFunctions[fullFunctionName]; exists { if _, exists := v.AllFunctions[fullFunctionName]; exists {
errors = append(errors, v.createError("duplicate function "+fullFunctionName, function.ReturnType.Position)) errors = append(errors, v.createError("duplicate function "+fullFunctionName, function.Position))
} }
v.AllFunctions[fullFunctionName] = function v.AllFunctions[fullFunctionName] = function