From a7007eaf0f15a7ae14e315a17822fddbe6c2fd50 Mon Sep 17 00:00:00 2001 From: MrLetsplay Date: Sat, 30 Mar 2024 21:57:38 +0100 Subject: [PATCH] Update language name, Improve implicit casts, Add array/raw memory expressions (WIP) --- .gitignore | 6 +- README.md | 2 + backend_wat.go | 173 ++++++++++---------- example/{a.lang => a.ely} | 0 example/{add.lang => add.ely} | 0 example/{b.lang => b.ely} | 0 example/{helloworld.lang => helloworld.ely} | 0 example/test.ely | 8 + example/test.lang | 5 - go.mod | 2 +- lexer.go | 3 +- main.go | 15 +- parser.go | 96 ++++++++++- stdlib/alloc.ely | 16 ++ stdlib/alloc.lang | 9 - validator.go | 144 +++++++++++----- 16 files changed, 319 insertions(+), 160 deletions(-) create mode 100644 README.md rename example/{a.lang => a.ely} (100%) rename example/{add.lang => add.ely} (100%) rename example/{b.lang => b.ely} (100%) rename example/{helloworld.lang => helloworld.ely} (100%) create mode 100644 example/test.ely delete mode 100644 example/test.lang create mode 100644 stdlib/alloc.ely delete mode 100644 stdlib/alloc.lang diff --git a/.gitignore b/.gitignore index 87da47a..26c95a9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,3 @@ -compiler -out.wat -build \ No newline at end of file +elysium +a.out +build diff --git a/README.md b/README.md new file mode 100644 index 0000000..7b54470 --- /dev/null +++ b/README.md @@ -0,0 +1,2 @@ +# Elysium +The Elysium programming language. diff --git a/backend_wat.go b/backend_wat.go index 959425a..8d88293 100644 --- a/backend_wat.go +++ b/backend_wat.go @@ -7,6 +7,11 @@ import ( "unicode" ) +type Compiler struct { + Files []*ParsedFile + Wasm64 bool +} + func getPrimitiveWATType(primitive PrimitiveType) string { switch primitive { case Primitive_I8, Primitive_I16, Primitive_I32, Primitive_U8, Primitive_U16, Primitive_U32: @@ -24,17 +29,6 @@ func getPrimitiveWATType(primitive PrimitiveType) string { panic("unhandled type") } -func getWATType(t Type) string { - // TODO: tuples? - - if t.Type != Type_Primitive { - panic("not implemented") // TODO: non-primitive types - } - - primitive := t.Value.(PrimitiveType) - return getPrimitiveWATType(primitive) -} - func safeASCIIIdentifier(identifier string) string { ascii := "" for _, rune := range identifier { @@ -85,6 +79,23 @@ func pushConstantNumberWAT(primitive PrimitiveType, value any) string { panic("invalid type") } +func (c *Compiler) getWATType(t Type) string { + switch t.Type { + case Type_Primitive: + return getPrimitiveWATType(t.Value.(PrimitiveType)) + case Type_Named, Type_Array: + if c.Wasm64 { + return "i64" + } else { + return "i32" + } + case Type_Tuple: + panic("tuple type passed to getWATType()") + } + + panic("type not implemented in getWATType()") +} + func castPrimitiveWAT(from PrimitiveType, to PrimitiveType) (string, error) { if from == to { return "", nil @@ -153,28 +164,23 @@ func castPrimitiveWAT(from PrimitiveType, to PrimitiveType) (string, error) { return "i32.wrap_i64\n" + getTypeCast(to), nil } -func compileExpressionWAT(expr Expression, block *Block) (string, error) { +func (c *Compiler) compileExpressionWAT(expr Expression, block *Block) (string, error) { var err error switch expr.Type { case Expression_Assignment: ass := expr.Value.(AssignmentExpression) - exprWAT, err := compileExpressionWAT(ass.Value, block) + exprWAT, err := c.compileExpressionWAT(ass.Value, block) if err != nil { return "", err } - cast := "" - if expr.ValueType.Type == Type_Primitive { - cast = getTypeCast(expr.ValueType.Value.(PrimitiveType)) - } - local := strconv.Itoa(block.Locals[ass.Variable].Index) getLocal := "local.get $" + local + "\n" setLocal := "local.set $" + local + "\n" - return exprWAT + cast + setLocal + getLocal, nil + return exprWAT + setLocal + getLocal, nil case Expression_Literal: lit := expr.Value.(LiteralExpression) switch lit.Literal.Type { @@ -194,6 +200,7 @@ func compileExpressionWAT(expr Expression, block *Block) (string, error) { cast := "" if expr.ValueType.Type == Type_Primitive { + // TODO: technically only needed for function parameters because functions can be called from outside WASM so they might not be fully type checked cast = getTypeCast(expr.ValueType.Value.(PrimitiveType)) } @@ -202,26 +209,15 @@ func compileExpressionWAT(expr Expression, block *Block) (string, error) { binary := expr.Value.(BinaryExpression) // TODO: currently fine, only allowed for primitive types, but should be expanded to allow e.g. strings - resultType := binary.ResultType.Value.(PrimitiveType) + operandType := binary.Left.ValueType.Value.(PrimitiveType) exprType := expr.ValueType.Value.(PrimitiveType) - watLeft, err := compileExpressionWAT(binary.Left, block) + watLeft, err := c.compileExpressionWAT(binary.Left, block) if err != nil { return "", err } - // TODO: cast produces unnecessary/wrong cast, make sure to upcast to target type - castLeft, err := castPrimitiveWAT(binary.Left.ValueType.Value.(PrimitiveType), resultType) - if err != nil { - return "", err - } - - watRight, err := compileExpressionWAT(binary.Right, block) - if err != nil { - return "", err - } - - castRight, err := castPrimitiveWAT(binary.Right.ValueType.Value.(PrimitiveType), resultType) + watRight, err := c.compileExpressionWAT(binary.Right, block) if err != nil { return "", err } @@ -229,7 +225,7 @@ func compileExpressionWAT(expr Expression, block *Block) (string, error) { op := "" suffix := "" - if isUnsignedInt(resultType) { + if isUnsignedInt(operandType) { suffix = "u" } else { suffix = "s" @@ -237,38 +233,38 @@ func compileExpressionWAT(expr Expression, block *Block) (string, error) { switch binary.Operation { case Operation_Add: - op = getPrimitiveWATType(resultType) + ".add\n" + op = getPrimitiveWATType(operandType) + ".add\n" case Operation_Sub: - op = getPrimitiveWATType(resultType) + ".sub\n" + op = getPrimitiveWATType(operandType) + ".sub\n" case Operation_Mul: - op = getPrimitiveWATType(resultType) + ".mul\n" + op = getPrimitiveWATType(operandType) + ".mul\n" case Operation_Div: - op = getPrimitiveWATType(resultType) + ".div_" + suffix + "\n" + op = getPrimitiveWATType(operandType) + ".div_" + suffix + "\n" case Operation_Mod: - op = getPrimitiveWATType(resultType) + ".rem_" + suffix + "\n" + op = getPrimitiveWATType(operandType) + ".rem_" + suffix + "\n" case Operation_Greater: - op = getPrimitiveWATType(resultType) + ".gt_" + suffix + "\n" + op = getPrimitiveWATType(operandType) + ".gt_" + suffix + "\n" case Operation_Less: - op = getPrimitiveWATType(resultType) + ".lt_" + suffix + "\n" + op = getPrimitiveWATType(operandType) + ".lt_" + suffix + "\n" case Operation_GreaterEquals: - op = getPrimitiveWATType(resultType) + ".ge_" + suffix + "\n" + op = getPrimitiveWATType(operandType) + ".ge_" + suffix + "\n" case Operation_LessEquals: - op = getPrimitiveWATType(resultType) + ".le_" + suffix + "\n" + op = getPrimitiveWATType(operandType) + ".le_" + suffix + "\n" case Operation_NotEquals: - op = getPrimitiveWATType(resultType) + ".ne\n" + op = getPrimitiveWATType(operandType) + ".ne\n" case Operation_Equals: - op = getPrimitiveWATType(resultType) + ".eq\n" + op = getPrimitiveWATType(operandType) + ".eq\n" default: panic("operation not implemented") } - return watLeft + castLeft + watRight + castRight + op + getTypeCast(exprType), nil + return watLeft + watRight + op + getTypeCast(exprType), nil case Expression_Tuple: tuple := expr.Value.(TupleExpression) wat := "" for _, member := range tuple.Members { - memberWAT, err := compileExpressionWAT(member, block) + memberWAT, err := c.compileExpressionWAT(member, block) if err != nil { return "", err } @@ -282,7 +278,7 @@ func compileExpressionWAT(expr Expression, block *Block) (string, error) { wat := "" if fc.Parameters != nil { - wat, err = compileExpressionWAT(*fc.Parameters, block) + wat, err = c.compileExpressionWAT(*fc.Parameters, block) if err != nil { return "", err } @@ -293,7 +289,7 @@ func compileExpressionWAT(expr Expression, block *Block) (string, error) { neg := expr.Value.(NegateExpression) exprType := expr.ValueType.Value.(PrimitiveType) - wat, err := compileExpressionWAT(neg.Value, block) + wat, err := c.compileExpressionWAT(neg.Value, block) if err != nil { return "", err } @@ -306,29 +302,50 @@ func compileExpressionWAT(expr Expression, block *Block) (string, error) { if isFloatingPoint(exprType) { return watType + ".neg\n", nil } + case Expression_Cast: + cast := expr.Value.(CastExpression) + + wat, err := c.compileExpressionWAT(cast.Value, block) + if err != nil { + return "", err + } + + // TODO: fine, as it is currently only allowed for primitive types + fromType := cast.Value.ValueType.Value.(PrimitiveType) + toType := cast.Type.Value.(PrimitiveType) + castWAT, err := castPrimitiveWAT(fromType, toType) + if err != nil { + return "", err + } + + return wat + castWAT, nil } panic("expr not implemented") } -func compileStatementWAT(stmt Statement, block *Block) (string, error) { +func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, error) { switch stmt.Type { case Statement_Expression: expr := stmt.Value.(ExpressionStatement) - wat, err := compileExpressionWAT(expr.Expression, block) + wat, err := c.compileExpressionWAT(expr.Expression, block) if err != nil { return "", err } - numItems := 1 - if expr.Expression.ValueType.Type == Type_Tuple { - numItems = len(expr.Expression.ValueType.Value.(TupleType).Types) + numItems := 0 + if expr.Expression.ValueType != nil { + numItems = 1 + + if expr.Expression.ValueType.Type == Type_Tuple { + numItems = len(expr.Expression.ValueType.Value.(TupleType).Types) + } } return wat + strings.Repeat("drop\n", numItems), nil case Statement_Block: block := stmt.Value.(BlockStatement) - wat, err := compileBlockWAT(block.Block) + wat, err := c.compileBlockWAT(block.Block) if err != nil { return "", err } @@ -336,7 +353,7 @@ func compileStatementWAT(stmt Statement, block *Block) (string, error) { return wat, nil case Statement_Return: ret := stmt.Value.(ReturnStatement) - wat, err := compileExpressionWAT(*ret.Value, block) + wat, err := c.compileExpressionWAT(*ret.Value, block) if err != nil { return "", err } @@ -349,30 +366,21 @@ func compileStatementWAT(stmt Statement, block *Block) (string, error) { return "", nil } - wat, err := compileExpressionWAT(*dlv.Initializer, block) + wat, err := c.compileExpressionWAT(*dlv.Initializer, block) if err != nil { return "", err } - if dlv.VariableType.Type == Type_Primitive { - castWAT, err := castPrimitiveWAT(dlv.Initializer.ValueType.Value.(PrimitiveType), dlv.VariableType.Value.(PrimitiveType)) - if err != nil { - return "", err - } - - wat += castWAT - } - return wat + "local.set $" + strconv.Itoa(block.Locals[dlv.Variable].Index) + "\n", nil case Statement_If: ifS := stmt.Value.(IfStatement) - conditionWAT, err := compileExpressionWAT(ifS.Condition, block) + conditionWAT, err := c.compileExpressionWAT(ifS.Condition, block) if err != nil { return "", err } - condBlockWAT, err := compileBlockWAT(ifS.ConditionalBlock) + condBlockWAT, err := c.compileBlockWAT(ifS.ConditionalBlock) if err != nil { return "", err } @@ -401,7 +409,7 @@ func compileStatementWAT(stmt Statement, block *Block) (string, error) { if ifS.ElseBlock != nil { // condition is false - elseWAT, err := compileBlockWAT(ifS.ElseBlock) + elseWAT, err := c.compileBlockWAT(ifS.ElseBlock) if err != nil { return "", err } @@ -416,11 +424,11 @@ func compileStatementWAT(stmt Statement, block *Block) (string, error) { panic("stmt not implemented") } -func compileBlockWAT(block *Block) (string, error) { +func (c *Compiler) compileBlockWAT(block *Block) (string, error) { blockWAT := "" for _, stmt := range block.Statements { - wat, err := compileStatementWAT(stmt, block) + wat, err := c.compileStatementWAT(stmt, block) if err != nil { return "", err } @@ -431,17 +439,16 @@ func compileBlockWAT(block *Block) (string, error) { return blockWAT, nil } -func compileFunctionWAT(function ParsedFunction) (string, error) { - funcWAT := "(func $" + safeASCIIIdentifier(function.FullName) + "\n" +func (c *Compiler) compileFunctionWAT(function ParsedFunction) (string, error) { + funcWAT := "(func $" + safeASCIIIdentifier(function.FullName) + " (export \"" + function.FullName + "\")\n" for _, local := range function.Locals { if !local.IsParameter { continue } - funcWAT += "\t(param $" + strconv.Itoa(local.Index) + " " + getWATType(local.Type) + ")\n" + funcWAT += "\t(param $" + strconv.Itoa(local.Index) + " " + c.getWATType(local.Type) + ")\n" } - // TODO: tuples returnTypes := []Type{} if function.ReturnType != nil { returnTypes = []Type{*function.ReturnType} @@ -451,32 +458,32 @@ func compileFunctionWAT(function ParsedFunction) (string, error) { } for _, t := range returnTypes { - funcWAT += "\t(result " + getWATType(t) + ")\n" + funcWAT += "\t(result " + c.getWATType(t) + ")\n" } for _, local := range function.Locals { if local.IsParameter { continue } - funcWAT += "\t(local $" + strconv.Itoa(local.Index) + " " + getWATType(local.Type) + ")\n" + funcWAT += "\t(local $" + strconv.Itoa(local.Index) + " " + c.getWATType(local.Type) + ")\n" } - wat, err := compileBlockWAT(function.Body) + wat, err := c.compileBlockWAT(function.Body) if err != nil { return "", err } funcWAT += wat - return funcWAT + ") (export \"" + function.FullName + "\" (func $" + safeASCIIIdentifier(function.FullName) + "))\n", nil + return funcWAT + ")\n", nil } -func backendWAT(files []*ParsedFile) (string, error) { - module := "(module (memory 1)\n" +func (c *Compiler) compile() (string, error) { + module := "(module\n" - for _, file := range files { + for _, file := range c.Files { for _, function := range file.Functions { - wat, err := compileFunctionWAT(function) + wat, err := c.compileFunctionWAT(function) if err != nil { return "", err } diff --git a/example/a.lang b/example/a.ely similarity index 100% rename from example/a.lang rename to example/a.ely diff --git a/example/add.lang b/example/add.ely similarity index 100% rename from example/add.lang rename to example/add.ely diff --git a/example/b.lang b/example/b.ely similarity index 100% rename from example/b.lang rename to example/b.ely diff --git a/example/helloworld.lang b/example/helloworld.ely similarity index 100% rename from example/helloworld.lang rename to example/helloworld.ely diff --git a/example/test.ely b/example/test.ely new file mode 100644 index 0000000..22d69bf --- /dev/null +++ b/example/test.ely @@ -0,0 +1,8 @@ +void b(u64 i) { + +} + +(u8, u16, u64) a() { + b(1u8); + return 1u8, 2u8, 3u8; +} diff --git a/example/test.lang b/example/test.lang deleted file mode 100644 index 93c20ff..0000000 --- a/example/test.lang +++ /dev/null @@ -1,5 +0,0 @@ -module sus; - -(u8, u8) a() { - return 1u8, 2u8; -} diff --git a/go.mod b/go.mod index 7e2ac80..da57197 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ -module cringe-studios.com/compiler +module git.cringe-studios.com/mr/elysium go 1.21.7 diff --git a/lexer.go b/lexer.go index df0da8d..e2b1778 100644 --- a/lexer.go +++ b/lexer.go @@ -22,7 +22,7 @@ const ( type Keyword uint32 -var Keywords []string = []string{"import", "module", "void", "return", "true", "false", "if", "else"} +var Keywords []string = []string{"import", "module", "void", "return", "true", "false", "if", "else", "raw"} const ( Keyword_Import Keyword = iota @@ -33,6 +33,7 @@ const ( KeyWord_False Keyword_If Keyword_Else + Keyword_Raw ) type Separator uint32 diff --git a/main.go b/main.go index 33072c1..e80b138 100644 --- a/main.go +++ b/main.go @@ -88,6 +88,7 @@ func readEmbedDir(name string, files map[string]string) { func main() { outputFile := flag.String("o", "a.out", "Output file") generateWAT := flag.Bool("wat", false, "Generate WAT instead of WASM") + wasm64 := flag.Bool("wasm64", false, "Use 64-bit memory (may not be supported in all browsers)") includeStdlib := flag.Bool("stdlib", true, "Include the standard library") flag.Parse() @@ -150,7 +151,7 @@ func main() { parsedFiles = append(parsedFiles, parsed) } - validator := Validator{files: parsedFiles} + validator := Validator{Files: parsedFiles} errors := validator.validate() if len(errors) != 0 { for _, err := range errors { @@ -168,7 +169,8 @@ func main() { // log.Printf("Validated:\n%+#v\n\n", parsedFiles) - wat, err := backendWAT(parsedFiles) + compiler := Compiler{Files: parsedFiles, Wasm64: *wasm64} + wat, err := compiler.compile() if err != nil { if c, ok := err.(CompilerError); ok { printCompilerError(fileSources, c) @@ -193,13 +195,8 @@ func main() { cmd.Stdin = &input - err = cmd.Start() + output, err := cmd.CombinedOutput() if err != nil { - log.Fatalln(err) - } - - err = cmd.Wait() - if err != nil { - log.Fatalln(err) + log.Fatalln(err, string(output)) } } diff --git a/parser.go b/parser.go index 59b4d28..1f2c346 100644 --- a/parser.go +++ b/parser.go @@ -82,6 +82,9 @@ const ( Expression_Tuple Expression_FunctionCall Expression_Negate + Expression_ArrayAccess + Expression_RawMemoryReference + Expression_Cast ) type Expression struct { @@ -122,10 +125,9 @@ const ( ) type BinaryExpression struct { - Operation Operation - Left Expression - Right Expression - ResultType *Type // Type to expand the operands to before performing the operation + Operation Operation + Left Expression + Right Expression } type TupleExpression struct { @@ -141,6 +143,21 @@ type NegateExpression struct { Value Expression } +type ArrayAccessExpression struct { + Array Expression + Index Expression +} + +type RawMemoryReferenceExpression struct { + Type Type + Address Expression +} + +type CastExpression struct { + Type Type + Value Expression +} + type Local struct { Name string Type Type @@ -293,8 +310,6 @@ func (p *Parser) tryType() (*Type, error) { } if tok.Type == Type_Identifier { - // TODO: array type - var theType Type index := slices.Index(PRIMITIVE_TYPE_NAMES, tok.Value.(string)) @@ -314,11 +329,15 @@ func (p *Parser) tryType() (*Type, error) { break } - _, err = pCopy.expectSeparator(Separator_CloseSquare) + sep, err = pCopy.trySeparator(Separator_CloseSquare) if err != nil { return nil, err } + if sep == nil { + return nil, nil + } + theType = Type{Type: Type_Array, Value: ArrayType{ElementType: theType}, Position: theType.Position} } @@ -421,7 +440,7 @@ func (p *Parser) tryParanthesizedExpression() (*Expression, error) { return expr, nil } -func (p *Parser) tryUnaryExpression() (*Expression, error) { +func (p *Parser) tryUnaryExpressionNoArrayAccess() (*Expression, error) { pCopy := p.copy() token := pCopy.peekToken() @@ -457,6 +476,28 @@ func (p *Parser) tryUnaryExpression() (*Expression, error) { *p = pCopy return &Expression{Type: Expression_Literal, Value: LiteralExpression{Literal: Literal{Type: Literal_Boolean, Primitive: Primitive_Bool, Value: keyword == Keyword_True}}, Position: token.Position}, nil } + + if keyword == Keyword_Raw { + pCopy.nextToken() + + rawType, err := pCopy.expectType() + if err != nil { + return nil, err + } + + _, err = pCopy.expectSeparator(Separator_Comma) + if err != nil { + return nil, err + } + + address, err := pCopy.expectExpression() + if err != nil { + return nil, err + } + + *p = pCopy + return &Expression{Type: Expression_RawMemoryReference, Value: RawMemoryReferenceExpression{Type: *rawType, Address: *address}, Position: token.Position}, nil + } } if token.Type == Type_Operator { @@ -513,6 +554,43 @@ func (p *Parser) tryUnaryExpression() (*Expression, error) { return nil, nil } +func (p *Parser) tryUnaryExpression() (*Expression, error) { + pCopy := p.copy() + + expr, err := pCopy.tryUnaryExpressionNoArrayAccess() // TODO: wrong precedence + if err != nil { + return nil, err + } + + if expr == nil { + return nil, nil + } + + for { + open, err := pCopy.trySeparator(Separator_OpenSquare) + if err != nil { + return nil, err + } + + if open == nil { + *p = pCopy + return expr, nil + } + + index, err := pCopy.expectExpression() + if err != nil { + return nil, err + } + + _, err = pCopy.expectSeparator(Separator_CloseSquare) + if err != nil { + return nil, err + } + + expr = &Expression{Type: Expression_ArrayAccess, Value: ArrayAccessExpression{Array: *expr, Index: *index}, Position: expr.Position} + } +} + func (p *Parser) tryBinaryExpression0(opFunc func() (*Expression, error), operators ...Operator) (*Expression, error) { left, err := opFunc() if err != nil { @@ -578,7 +656,7 @@ func (p *Parser) tryAssignmentExpression() (*Expression, error) { return nil, nil } - if lhs.Type != Expression_VariableReference { // TODO: allow other types + if lhs.Type != Expression_VariableReference { // TODO: allow other types (array access) return p.tryBinaryExpression() } diff --git a/stdlib/alloc.ely b/stdlib/alloc.ely new file mode 100644 index 0000000..6f39644 --- /dev/null +++ b/stdlib/alloc.ely @@ -0,0 +1,16 @@ +module alloc; + +u64 alloc(u64 size) { + u64 ptr = 0x0u64; + raw(i32, ptr) = 0x03u32; + i32 sus = raw(i32, ptr); + return 0u64; +} + +void free(u64 address) { + +} + +u64 growMemory(u64 numPages) { + return 0u64; +} diff --git a/stdlib/alloc.lang b/stdlib/alloc.lang deleted file mode 100644 index a352285..0000000 --- a/stdlib/alloc.lang +++ /dev/null @@ -1,9 +0,0 @@ -module alloc; - -u64 alloc(u64 size) { - return 0u64; -} - -void free(u64 address) { - -} diff --git a/validator.go b/validator.go index 8e8e041..64fcce3 100644 --- a/validator.go +++ b/validator.go @@ -6,11 +6,43 @@ import ( ) type Validator struct { - files []*ParsedFile - allFunctions map[string]*ParsedFunction + Files []*ParsedFile + AllFunctions map[string]*ParsedFunction - currentBlock *Block - currentFunction *ParsedFunction + CurrentBlock *Block + CurrentFunction *ParsedFunction +} + +func isSameType(a Type, b Type) bool { + if a.Type != b.Type { + return false + } + + switch a.Type { + case Type_Primitive: + return a.Value.(PrimitiveType) == b.Value.(PrimitiveType) + case Type_Named: + return a.Value.(NamedType).TypeName == b.Value.(NamedType).TypeName + case Type_Array: + return isSameType(a.Value.(ArrayType).ElementType, b.Value.(ArrayType).ElementType) + case Type_Tuple: + aTuple := a.Value.(TupleType) + bTuple := b.Value.(TupleType) + + if len(aTuple.Types) != len(bTuple.Types) { + return false + } + + for i := 0; i < len(aTuple.Types); i++ { + if !isSameType(aTuple.Types[i], bTuple.Types[i]) { + return false + } + } + + return true + } + + panic("type not implemented") } func isTypeExpandableTo(from Type, to Type) bool { @@ -40,9 +72,35 @@ func isTypeExpandableTo(from Type, to Type) bool { return true } + if from.Type == Type_Array { + return isSameType(from.Value.(ArrayType).ElementType, to.Value.(ArrayType).ElementType) + } + panic("not implemented") } +func expandExpressionToType(expr *Expression, to Type) { + // TODO: merge with isTypeExpandableTo? + + if isSameType(*expr.ValueType, to) { + return + } + + if expr.Type == Expression_Tuple { + tupleExpr := expr.Value.(TupleExpression) + tupleType := to.Value.(TupleType) + + for i := 0; i < len(tupleType.Types); i++ { + expandExpressionToType(&tupleExpr.Members[i], tupleType.Types[i]) + } + + expr.Value = tupleExpr + return + } + + *expr = Expression{Type: Expression_Cast, Value: CastExpression{Type: to, Value: *expr}, ValueType: &to, Position: expr.Position} +} + func isPrimitiveTypeExpandableTo(from PrimitiveType, to PrimitiveType) bool { if from == to { return true @@ -118,7 +176,7 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error switch expr.Type { case Expression_Assignment: assignment := expr.Value.(AssignmentExpression) - local := getLocal(v.currentBlock, assignment.Variable) + local := getLocal(v.CurrentBlock, assignment.Variable) if local == nil { errors = append(errors, v.createError("Assignment to undeclared variable "+assignment.Variable, expr.Position)) return errors @@ -130,8 +188,12 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error return errors } - if !isTypeExpandableTo(*assignment.Value.ValueType, local.Type) { - errors = append(errors, v.createError(fmt.Sprintf("cannot assign expression of type %s to variable of type %s", *assignment.Value.ValueType, local.Type), expr.Position)) + if !isSameType(*assignment.Value.ValueType, local.Type) { + if !isTypeExpandableTo(*assignment.Value.ValueType, local.Type) { + errors = append(errors, v.createError(fmt.Sprintf("cannot assign expression of type %s to variable of type %s", *assignment.Value.ValueType, local.Type), expr.Position)) + } + + expandExpressionToType(&assignment.Value, local.Type) } expr.ValueType = &local.Type @@ -147,7 +209,7 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error } case Expression_VariableReference: reference := expr.Value.(VariableReferenceExpression) - local := getLocal(v.currentBlock, reference.Variable) + local := getLocal(v.CurrentBlock, reference.Variable) if local == nil { errors = append(errors, v.createError("Reference to undeclared variable "+reference.Variable, expr.Position)) return errors @@ -170,24 +232,19 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error return errors } - leftType := binary.Left.ValueType.Value.(PrimitiveType) - rightType := binary.Right.ValueType.Value.(PrimitiveType) - - var result PrimitiveType = InvalidValue - if isPrimitiveTypeExpandableTo(leftType, rightType) { - result = leftType - } - - if isPrimitiveTypeExpandableTo(rightType, leftType) { - result = leftType - } - - if result == InvalidValue { - errors = append(errors, v.createError(fmt.Sprintf("cannot compare the types %s and %s without an explicit cast", leftType, rightType), expr.Position)) + var operandType Type + if isTypeExpandableTo(*binary.Left.ValueType, *binary.Right.ValueType) { + operandType = *binary.Right.ValueType + } else if isTypeExpandableTo(*binary.Right.ValueType, *binary.Left.ValueType) { + operandType = *binary.Left.ValueType + } else { + errors = append(errors, v.createError(fmt.Sprintf("cannot compare the types %s and %s without an explicit cast", binary.Left.ValueType.Value.(PrimitiveType), binary.Right.ValueType.Value.(PrimitiveType)), expr.Position)) return errors } - binary.ResultType = &Type{Type: Type_Primitive, Value: result} + expandExpressionToType(&binary.Left, operandType) + expandExpressionToType(&binary.Right, operandType) + expr.ValueType = &Type{Type: Type_Primitive, Value: Primitive_Bool} } @@ -205,7 +262,6 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error return errors } - binary.ResultType = &Type{Type: Type_Primitive, Value: result} expr.ValueType = &Type{Type: Type_Primitive, Value: result} } @@ -235,7 +291,7 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error case Expression_FunctionCall: fc := expr.Value.(FunctionCallExpression) - calledFunc, ok := v.allFunctions[fc.Function] + calledFunc, ok := v.AllFunctions[fc.Function] if !ok { errors = append(errors, v.createError("call to undefined function '"+fc.Function+"'", expr.Position)) return errors @@ -248,9 +304,11 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error return errors } - params := []Expression{*fc.Parameters} + params := []*Expression{fc.Parameters} if fc.Parameters.Type == Expression_Tuple { - params = fc.Parameters.Value.(TupleExpression).Members + for i := 0; i < len(fc.Parameters.Value.(TupleExpression).Members); i++ { + params[i] = &fc.Parameters.Value.(TupleExpression).Members[i] + } } if len(params) != len(calledFunc.Parameters) { @@ -263,6 +321,8 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error if !isTypeExpandableTo(*typeGiven.ValueType, typeExpected.Type) { errors = append(errors, v.createError("invalid type for parameter "+strconv.Itoa(i), expr.Position)) } + + expandExpressionToType(typeGiven, typeExpected.Type) } } @@ -319,7 +379,7 @@ func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local) case Statement_Return: ret := stmt.Value.(ReturnStatement) if ret.Value != nil { - if v.currentFunction.ReturnType == nil { + if v.CurrentFunction.ReturnType == nil { errors = append(errors, v.createError("cannot return value from void function", stmt.Position)) return errors } @@ -330,10 +390,12 @@ func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local) return errors } - if !isTypeExpandableTo(*ret.Value.ValueType, *v.currentFunction.ReturnType) { - errors = append(errors, v.createError(fmt.Sprintf("cannot return value of type %s from function returning %s", *ret.Value.ValueType, *v.currentFunction.ReturnType), ret.Value.Position)) + if !isTypeExpandableTo(*ret.Value.ValueType, *v.CurrentFunction.ReturnType) { + errors = append(errors, v.createError(fmt.Sprintf("cannot return value of type %s from function returning %s", *ret.Value.ValueType, *v.CurrentFunction.ReturnType), ret.Value.Position)) } - } else if v.currentFunction.ReturnType != nil { + + expandExpressionToType(ret.Value, *v.CurrentFunction.ReturnType) + } else if v.CurrentFunction.ReturnType != nil { errors = append(errors, v.createError("missing return value", stmt.Position)) } @@ -349,14 +411,16 @@ func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local) if !isTypeExpandableTo(*dlv.Initializer.ValueType, dlv.VariableType) { errors = append(errors, v.createError(fmt.Sprintf("cannot assign expression of type %s to variable of type %s", *dlv.Initializer.ValueType, dlv.VariableType), stmt.Position)) } + + expandExpressionToType(dlv.Initializer, dlv.VariableType) } - if getLocal(v.currentBlock, dlv.Variable) != nil { + if getLocal(v.CurrentBlock, dlv.Variable) != nil { errors = append(errors, v.createError("redeclaration of variable '"+dlv.Variable+"'", stmt.Position)) } local := Local{Name: dlv.Variable, Type: dlv.VariableType, IsParameter: false, Index: len(*functionLocals)} - v.currentBlock.Locals[dlv.Variable] = local + v.CurrentBlock.Locals[dlv.Variable] = local *functionLocals = append(*functionLocals, local) stmt.Value = dlv @@ -394,7 +458,7 @@ func (v *Validator) validateBlock(block *Block, functionLocals *[]Local) []error } for i := range block.Statements { - v.currentBlock = block + v.CurrentBlock = block stmt := &block.Statements[i] errors = append(errors, v.validateStatement(stmt, functionLocals)...) } @@ -407,7 +471,7 @@ func (v *Validator) validateFunction(function *ParsedFunction) []error { var locals []Local - v.currentFunction = function + v.CurrentFunction = function body := function.Body body.Locals = make(map[string]Local) @@ -429,8 +493,8 @@ func (v *Validator) validateFunction(function *ParsedFunction) []error { func (v *Validator) validate() []error { var errors []error - v.allFunctions = make(map[string]*ParsedFunction) - for _, file := range v.files { + v.AllFunctions = make(map[string]*ParsedFunction) + for _, file := range v.Files { for i := range file.Functions { function := &file.Functions[i] @@ -441,15 +505,15 @@ func (v *Validator) validate() []error { function.FullName = fullFunctionName - if _, exists := v.allFunctions[fullFunctionName]; exists { + if _, exists := v.AllFunctions[fullFunctionName]; exists { errors = append(errors, v.createError("duplicate function "+fullFunctionName, function.ReturnType.Position)) } - v.allFunctions[fullFunctionName] = function + v.AllFunctions[fullFunctionName] = function } } - for _, file := range v.files { + for _, file := range v.Files { for i := range file.Imports { errors = append(errors, v.validateImport(&file.Imports[i])...) }