From 3af1535515051b9a7b172417308775b5adc3caaf Mon Sep 17 00:00:00 2001 From: MrLetsplay Date: Thu, 31 Oct 2024 20:50:09 +0100 Subject: [PATCH] Array assignment --- README.md | 1 + backend_wat.go | 122 +++++++++++++++++++++++++++++++++++++--------- example/array.ely | 27 ++++++++++ example/test.ely | 4 -- main.go | 13 ++++- parser.go | 3 +- validator.go | 4 +- 7 files changed, 144 insertions(+), 30 deletions(-) create mode 100644 example/array.ely diff --git a/README.md b/README.md index 2d9d1db..ee37e95 100644 --- a/README.md +++ b/README.md @@ -17,3 +17,4 @@ The Elysium programming language. - [ ] Memory allocation, Heap allocator - [ ] Garbage collector - [ ] Support for wasm64 +- [ ] Imported functions (from JS) diff --git a/backend_wat.go b/backend_wat.go index e6acab8..9817357 100644 --- a/backend_wat.go +++ b/backend_wat.go @@ -8,9 +8,12 @@ import ( "unicode" ) +const COMPILE_OPTION_NO_BOUNDS_CHECK = "no_bounds_check" + type Compiler struct { - Files []*ParsedFile - Wasm64 bool + Files []*ParsedFile + Wasm64 bool + CompileOptions map[string]string CurrentBlock *Block CurrentFunction *ParsedFunction @@ -58,7 +61,7 @@ func getTypeCast(primitive PrimitiveType) string { case Primitive_U16: return "i32.const 65535\ni32.and\n" case Primitive_Bool: - return "i32.const 1\ni32.and\n" + return "i32.const 0\ni32.ne\n" } return "" @@ -222,7 +225,78 @@ func (c *Compiler) compileAssignmentExpressionWAT(assignment AssignmentExpressio local := strconv.Itoa(getLocal(c.CurrentBlock, ref.Variable).Index) return exprWAT + "local.tee $" + local + "\n", nil case Expression_ArrayAccess: - panic("TODO") // TODO + array := lhs.Value.(ArrayAccessExpression) + + localArray := Local{Name: "", Type: Type{Type: Type_Primitive, Value: c.getEffectiveAddressType(), Position: unknownPosition()}, IsParameter: false, Index: len(c.CurrentFunction.Locals)} + c.CurrentFunction.Locals = append(c.CurrentFunction.Locals, localArray) + + localIndex := Local{Name: "", Type: Type{Type: Type_Primitive, Value: c.getEffectiveAddressType(), Position: unknownPosition()}, IsParameter: false, Index: len(c.CurrentFunction.Locals)} + c.CurrentFunction.Locals = append(c.CurrentFunction.Locals, localIndex) + + localElement := Local{Name: "", Type: *assignment.Value.ValueType, IsParameter: false, Index: len(c.CurrentFunction.Locals)} + c.CurrentFunction.Locals = append(c.CurrentFunction.Locals, localElement) + + arrayWAT, err := c.compileExpressionWAT(array.Array) + if err != nil { + return "", err + } + + arrayWAT += "local.set $" + strconv.Itoa(localArray.Index) + "\n" + + indexWAT, err := c.compileExpressionWAT(array.Index) + if err != nil { + return "", err + } + + if !c.Wasm64 { + cast, err := castPrimitiveWAT(Primitive_I64, Primitive_I32) + if err != nil { + return "", err + } + + indexWAT += cast + } + + indexWAT += "local.set $" + strconv.Itoa(localIndex.Index) + "\n" + + wat := arrayWAT + indexWAT + + if _, ok := c.CompileOptions[COMPILE_OPTION_NO_BOUNDS_CHECK]; !ok { + // Error if index < 0 + wat += "block\n" + wat += "local.get $" + strconv.Itoa(localIndex.Index) + "\n" + wat += c.getAddressWATType() + ".const 0\n" + wat += c.getAddressWATType() + ".ge_s\n" + wat += "br_if 0\n" + wat += "call $__builtin_panic\n" + wat += "end\n" + + // Error if index >= array length + wat += "block\n" + wat += "local.get $" + strconv.Itoa(localIndex.Index) + "\n" + wat += "local.get $" + strconv.Itoa(localArray.Index) + "\n" + wat += "i32.load\n" // Load array length + wat += c.getAddressWATType() + ".lt_s\n" + wat += "br_if 0\n" + wat += "call $__builtin_panic\n" + wat += "end\n" + } + + elementType := array.Array.ValueType.Value.(ArrayType).ElementType + wat += "local.get $" + strconv.Itoa(localIndex.Index) + "\n" + wat += c.getAddressWATType() + ".const " + strconv.Itoa(c.getTypeSizeBytes(elementType)) + "\n" + wat += c.getAddressWATType() + ".mul\n" + wat += "local.get $" + strconv.Itoa(localArray.Index) + "\n" + wat += c.getAddressWATType() + ".add\n" + wat += c.getAddressWATType() + ".const 4\n" // first 4 bytes = length + wat += c.getAddressWATType() + ".add\n" + + wat += exprWAT + wat += "local.tee $" + strconv.Itoa(localElement.Index) + "\n" + wat += c.getWATType(elementType) + ".store\n" // TODO: use load8/load16(_s/u) for smaller types + wat += "local.get $" + strconv.Itoa(localElement.Index) + "\n" + + return wat, nil case Expression_RawMemoryReference: raw := lhs.Value.(RawMemoryReferenceExpression) @@ -240,9 +314,9 @@ func (c *Compiler) compileAssignmentExpressionWAT(assignment AssignmentExpressio // TODO: should leave a copy of the stored value on the stack return addrWAT + exprWAT + - "local.tee " + strconv.Itoa(local.Index) + "\n" + + "local.tee $" + strconv.Itoa(local.Index) + "\n" + c.getWATType(raw.Type) + ".store\n" + - "local.get " + strconv.Itoa(local.Index) + "\n", nil + "local.get $" + strconv.Itoa(local.Index) + "\n", nil } panic("assignment expr not implemented") @@ -555,24 +629,26 @@ func (c *Compiler) compileExpressionWAT(expr Expression) (string, error) { wat := arrayWAT + indexWAT - // Error if index <= 0 - wat += "block\n" - wat += "local.get $" + strconv.Itoa(localIndex.Index) + "\n" - wat += c.getAddressWATType() + ".const 0\n" - wat += c.getAddressWATType() + ".gt_s\n" - wat += "br_if 0\n" - wat += "call $__builtin_panic\n" - wat += "end\n" + if _, ok := c.CompileOptions[COMPILE_OPTION_NO_BOUNDS_CHECK]; !ok { + // Error if index < 0 + wat += "block\n" + wat += "local.get $" + strconv.Itoa(localIndex.Index) + "\n" + wat += c.getAddressWATType() + ".const 0\n" + wat += c.getAddressWATType() + ".ge_s\n" + wat += "br_if 0\n" + wat += "call $__builtin_panic\n" + wat += "end\n" - // Error if index >= array length - wat += "block\n" - wat += "local.get $" + strconv.Itoa(localIndex.Index) + "\n" - wat += "local.get $" + strconv.Itoa(localArray.Index) + "\n" - wat += "i32.load\n" // Load array length - wat += c.getAddressWATType() + ".lt_s\n" - wat += "br_if 0\n" - wat += "call $__builtin_panic\n" - wat += "end\n" + // Error if index >= array length + wat += "block\n" + wat += "local.get $" + strconv.Itoa(localIndex.Index) + "\n" + wat += "local.get $" + strconv.Itoa(localArray.Index) + "\n" + wat += "i32.load\n" // Load array length + wat += c.getAddressWATType() + ".lt_s\n" + wat += "br_if 0\n" + wat += "call $__builtin_panic\n" + wat += "end\n" + } elementType := array.Array.ValueType.Value.(ArrayType).ElementType wat += "local.get $" + strconv.Itoa(localIndex.Index) + "\n" diff --git a/example/array.ely b/example/array.ely new file mode 100644 index 0000000..1a24359 --- /dev/null +++ b/example/array.ely @@ -0,0 +1,27 @@ +u64 array() { + if(__builtin_memory_size() == 0u64) { + __builtin_memory_grow(1u64); + } + + raw(u32, 0x0u64) = 1u32; + raw(u64, 0x04u64) = 69u64; + + u64[] array = raw(u64[], 0x0u64); + array[0] = 1u64; + array[1] = 2u64; + + return array[0] + array[1]; +} + +void array2() { + u64[] array = raw(u64[], 0x0u64); + array[0] = 1u64; +} + +u64[] arrayReturn() { + if(__builtin_memory_size() == 0u64) { + __builtin_memory_grow(1u64); + } + + return raw(u64[], 0x0u64); +} diff --git a/example/test.ely b/example/test.ely index 2591e34..4cb5340 100644 --- a/example/test.ely +++ b/example/test.ely @@ -38,7 +38,3 @@ u64 assign(u64 a) { a += 1u64; return raw(u64, a) += 2u64; } - -u64 test() { - return raw(u64[], 0x0u8)[0]; -} diff --git a/main.go b/main.go index 6befef9..f1d24a2 100644 --- a/main.go +++ b/main.go @@ -90,6 +90,7 @@ func main() { generateWAT := flag.Bool("wat", false, "Generate WAT instead of WASM") wasm64 := flag.Bool("wasm64", false, "Use 64-bit memory (may not be supported in all browsers)") includeStdlib := flag.Bool("stdlib", true, "Include the standard library") + compileOptions := flag.String("compileOptions", "", "The compile options (key=value,key2=value2,key3)") flag.Parse() if len(os.Args) < 2 { @@ -169,7 +170,17 @@ func main() { // log.Printf("Validated:\n%+#v\n\n", parsedFiles) - compiler := Compiler{Files: parsedFiles, Wasm64: *wasm64} + compileOptionsMap := make(map[string]string, 0) + for _, value := range strings.Split(*compileOptions, ",") { + kv := strings.Split(value, "=") + if len(kv) == 1 { + compileOptionsMap[kv[0]] = "" + } else { + compileOptionsMap[kv[0]] = kv[1] + } + } + + compiler := Compiler{Files: parsedFiles, Wasm64: *wasm64, CompileOptions: compileOptionsMap} wat, err := compiler.compile() if err != nil { if c, ok := err.(CompilerError); ok { diff --git a/parser.go b/parser.go index ab54a58..8c32c5e 100644 --- a/parser.go +++ b/parser.go @@ -206,6 +206,7 @@ type ParsedFunction struct { ReturnType *Type Body *Block Locals []Local // All of the locals of the function, ordered by their index + Position TokenPosition } type Import struct { @@ -1101,7 +1102,7 @@ func (p *Parser) expectFunction() (*ParsedFunction, error) { return nil, err } - return &ParsedFunction{Name: name, Parameters: parameters, ReturnType: returnType, Body: body}, nil + return &ParsedFunction{Name: name, Parameters: parameters, ReturnType: returnType, Body: body, Position: tok.Position}, nil } func (p *Parser) parseFile() (*ParsedFile, error) { diff --git a/validator.go b/validator.go index b0bb577..3c000fe 100644 --- a/validator.go +++ b/validator.go @@ -25,12 +25,14 @@ var builtinFunctions map[string]*ParsedFunction = map[string]*ParsedFunction{ FullName: "builtin." + BUILTIN_MEMORY_GROW, Parameters: []ParsedParameter{{Name: "memory", Type: Type{Type: Type_Primitive, Value: Primitive_U64, Position: unknownPosition()}}}, ReturnType: &Type{Type: Type_Primitive, Value: Primitive_I64, Position: unknownPosition()}, + Position: unknownPosition(), }, BUILTIN_MEMORY_SIZE: { Name: BUILTIN_MEMORY_SIZE, FullName: "builtin." + BUILTIN_MEMORY_SIZE, Parameters: []ParsedParameter{}, ReturnType: &Type{Type: Type_Primitive, Value: Primitive_U64, Position: unknownPosition()}, + Position: unknownPosition(), }, } @@ -640,7 +642,7 @@ func (v *Validator) validate() []error { function.FullName = fullFunctionName if _, exists := v.AllFunctions[fullFunctionName]; exists { - errors = append(errors, v.createError("duplicate function "+fullFunctionName, function.ReturnType.Position)) + errors = append(errors, v.createError("duplicate function "+fullFunctionName, function.Position)) } v.AllFunctions[fullFunctionName] = function