diff --git a/backend_wat.go b/backend_wat.go index 40f04ba..959425a 100644 --- a/backend_wat.go +++ b/backend_wat.go @@ -3,6 +3,7 @@ package main import ( "errors" "strconv" + "strings" "unicode" ) @@ -319,7 +320,12 @@ func compileStatementWAT(stmt Statement, block *Block) (string, error) { return "", err } - return wat + "drop\n", nil + numItems := 1 + if expr.Expression.ValueType.Type == Type_Tuple { + numItems = len(expr.Expression.ValueType.Value.(TupleType).Types) + } + + return wat + strings.Repeat("drop\n", numItems), nil case Statement_Block: block := stmt.Value.(BlockStatement) wat, err := compileBlockWAT(block.Block) @@ -426,7 +432,7 @@ func compileBlockWAT(block *Block) (string, error) { } func compileFunctionWAT(function ParsedFunction) (string, error) { - funcWAT := "(func $" + safeASCIIIdentifier(function.Name) + "\n" + funcWAT := "(func $" + safeASCIIIdentifier(function.FullName) + "\n" for _, local := range function.Locals { if !local.IsParameter { @@ -462,19 +468,21 @@ func compileFunctionWAT(function ParsedFunction) (string, error) { funcWAT += wat - return funcWAT + ") (export \"" + function.Name + "\" (func $" + safeASCIIIdentifier(function.Name) + "))\n", nil + return funcWAT + ") (export \"" + function.FullName + "\" (func $" + safeASCIIIdentifier(function.FullName) + "))\n", nil } -func backendWAT(file ParsedFile) (string, error) { +func backendWAT(files []*ParsedFile) (string, error) { module := "(module (memory 1)\n" - for _, function := range file.Functions { - wat, err := compileFunctionWAT(function) - if err != nil { - return "", err - } + for _, file := range files { + for _, function := range file.Functions { + wat, err := compileFunctionWAT(function) + if err != nil { + return "", err + } - module += wat + module += wat + } } module += ")" diff --git a/lexer.go b/lexer.go index 4b80dd2..df0da8d 100644 --- a/lexer.go +++ b/lexer.go @@ -22,10 +22,11 @@ const ( type Keyword uint32 -var Keywords []string = []string{"import", "void", "return", "true", "false", "if", "else"} +var Keywords []string = []string{"import", "module", "void", "return", "true", "false", "if", "else"} const ( Keyword_Import Keyword = iota + Keyword_Module Keyword_Void Keyword_Return Keyword_True @@ -36,7 +37,7 @@ const ( type Separator uint32 -var Separators []rune = []rune{'(', ')', '{', '}', '[', ']', ';', ','} +var Separators []rune = []rune{'(', ')', '{', '}', '[', ']', ';', ',', '.'} const ( Separator_OpenParen Separator = iota @@ -47,6 +48,7 @@ const ( Separator_CloseSquare Separator_Semicolon Separator_Comma + Separator_Dot ) type Operator uint32 @@ -84,10 +86,15 @@ const ( type LexToken struct { Type LexType - Position uint64 + Position TokenPosition Value any } +type TokenPosition struct { + SourceFile string + Position uint64 +} + type Literal struct { Type LiteralType Primitive PrimitiveType @@ -97,11 +104,12 @@ type Literal struct { type Lexer struct { Runes []rune LastTokenPosition uint64 + SourceFile string Position uint64 } func (l *Lexer) error(message string) error { - return CompilerError{Position: l.LastTokenPosition, Message: message} + return CompilerError{Position: TokenPosition{SourceFile: l.SourceFile, Position: l.LastTokenPosition}, Message: message} } func (l *Lexer) peekRune() *rune { @@ -210,12 +218,12 @@ func (l *Lexer) nextToken() (*LexToken, error) { return nil, err } - return &LexToken{Type: Type_Literal, Position: l.LastTokenPosition, Value: Literal{Type: Literal_String, Primitive: InvalidValue, Value: literal}}, nil + return &LexToken{Type: Type_Literal, Position: TokenPosition{SourceFile: l.SourceFile, Position: l.LastTokenPosition}, Value: Literal{Type: Literal_String, Primitive: InvalidValue, Value: literal}}, nil } op := l.tryOperator() if op != InvalidValue { - return &LexToken{Type: Type_Operator, Position: l.LastTokenPosition, Value: op}, nil + return &LexToken{Type: Type_Operator, Position: TokenPosition{SourceFile: l.SourceFile, Position: l.LastTokenPosition}, Value: op}, nil } token := "" @@ -238,8 +246,6 @@ func (l *Lexer) nextToken() (*LexToken, error) { runes := []rune(token) if unicode.IsDigit([]rune(token)[0]) { - // TODO: hexadecimal/binary/octal constants - var numberType PrimitiveType = InvalidValue var rawNumber string = token for i, name := range PRIMITIVE_TYPE_NAMES { @@ -268,20 +274,20 @@ func (l *Lexer) nextToken() (*LexToken, error) { return nil, err } - return &LexToken{Type: Type_Literal, Position: l.LastTokenPosition, Value: Literal{Type: Literal_Number, Primitive: numberType, Value: number}}, nil + return &LexToken{Type: Type_Literal, Position: TokenPosition{SourceFile: l.SourceFile, Position: l.LastTokenPosition}, Value: Literal{Type: Literal_Number, Primitive: numberType, Value: number}}, nil } if len(runes) == 1 { if idx := slices.Index(Separators, runes[0]); idx != -1 { - return &LexToken{Type: Type_Separator, Position: l.LastTokenPosition, Value: Separator(idx)}, nil + return &LexToken{Type: Type_Separator, Position: TokenPosition{SourceFile: l.SourceFile, Position: l.LastTokenPosition}, Value: Separator(idx)}, nil } } if idx := slices.Index(Keywords, token); idx != -1 { - return &LexToken{Type: Type_Keyword, Position: l.LastTokenPosition, Value: Keyword(idx)}, nil + return &LexToken{Type: Type_Keyword, Position: TokenPosition{SourceFile: l.SourceFile, Position: l.LastTokenPosition}, Value: Keyword(idx)}, nil } - return &LexToken{Type: Type_Identifier, Position: l.LastTokenPosition, Value: token}, nil + return &LexToken{Type: Type_Identifier, Position: TokenPosition{SourceFile: l.SourceFile, Position: l.LastTokenPosition}, Value: token}, nil } func (l *Lexer) parseNumber(raw string, numberType PrimitiveType) (any, error) { @@ -335,10 +341,10 @@ func (l *Lexer) parseNumber(raw string, numberType PrimitiveType) (any, error) { panic(fmt.Sprintf("Unhandled type %s in parseNumber()", numberType)) } -func lexer(program string) ([]LexToken, error) { +func lexer(sourceFile string, source string) ([]LexToken, error) { var tokens []LexToken - lexer := Lexer{Runes: []rune(program)} + lexer := Lexer{SourceFile: sourceFile, Runes: []rune(source)} for { token, err := lexer.nextToken() diff --git a/main.go b/main.go index 6ea837c..c6d5784 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( + "embed" "log" "os" "strconv" @@ -10,6 +11,9 @@ import ( const ERROR_LOG_LINES = 5 +//go:embed stdlib/* +var stdlib embed.FS + func countTabs(line string) int { tabs := 0 for _, rune := range line { @@ -20,13 +24,19 @@ func countTabs(line string) int { return tabs } -func printCompilerError(file string, source string, err CompilerError) { +func printCompilerError(sources map[string]string, err CompilerError) { + source, ok := sources[err.Position.SourceFile] + if !ok { + log.Println(err) + return + } + sourceRunes := []rune(source) lines := strings.Split(source, "\n") line := 0 col := 0 var i uint64 - for i = 0; i < err.Position; i++ { + for i = 0; i < err.Position.Position; i++ { col++ if sourceRunes[i] == '\n' { line++ @@ -34,7 +44,7 @@ func printCompilerError(file string, source string, err CompilerError) { } } - log.Println("Failed to compile: " + err.Message + " (at " + file + ":" + strconv.Itoa(line+1) + ":" + strconv.Itoa(col+1) + ")") + log.Println("Failed to compile: " + err.Message + " (at " + err.Position.SourceFile + ":" + strconv.Itoa(line+1) + ":" + strconv.Itoa(col+1) + ")") linesStart := max(0, line-ERROR_LOG_LINES) linesEnd := min(len(lines), line+ERROR_LOG_LINES+1) @@ -51,52 +61,91 @@ func printCompilerError(file string, source string, err CompilerError) { } } +func readEmbedDir(name string, files map[string]string) { + entries, err := stdlib.ReadDir(name) + if err != nil { + log.Fatalln(err) + } + + for _, entry := range entries { + fullName := name + "/" + entry.Name() + if entry.IsDir() { + readEmbedDir(fullName, files) + } else { + bytes, err := stdlib.ReadFile(fullName) + if err != nil { + log.Fatalln(err) + } + + files[fullName] = string(bytes) + } + } +} + func main() { - if len(os.Args) != 2 { - log.Fatalln("Usage: " + os.Args[0] + " ") + if len(os.Args) < 2 { + log.Fatalln("Usage: " + os.Args[0] + " ") } - file := os.Args[1] - content, err := os.ReadFile(file) - if err != nil { - log.Fatalln("Cannot open input file.", err) - } + files := os.Args[1:] - source := string(content) - - tokens, err := lexer(source) - if err != nil { - if c, ok := err.(CompilerError); ok { - printCompilerError(file, source, c) - } else { - log.Println(err) + fileSources := make(map[string]string) + for _, file := range files { + content, err := os.ReadFile(file) + if err != nil { + log.Fatalln("Cannot open input file.", err) } - return + fileSources[file] = string(content) } - log.Printf("Tokens:\n%+#v\n\n", tokens) + stdlibFiles := make(map[string]string) + readEmbedDir("stdlib", stdlibFiles) + for path, file := range stdlibFiles { + fileSources["[embedded]/"+path] = file + } - parser := Parser{Tokens: tokens} - parsed, err := parser.parseFile() - if err != nil { - if c, ok := err.(CompilerError); ok { - printCompilerError(file, source, c) - } else { - log.Println(err) + fileTokens := make(map[string][]LexToken) + for file, source := range fileSources { + tokens, err := lexer(file, source) + if err != nil { + if c, ok := err.(CompilerError); ok { + printCompilerError(fileSources, c) + } else { + log.Println(err) + } + + return } - return + // log.Printf("Tokens:\n%+#v\n\n", tokens) + fileTokens[file] = tokens } - log.Printf("Parsed:\n%+#v\n\n", parsed) + var parsedFiles []*ParsedFile + for _, tokens := range fileTokens { + parser := Parser{Tokens: tokens} + parsed, err := parser.parseFile() + if err != nil { + if c, ok := err.(CompilerError); ok { + printCompilerError(fileSources, c) + } else { + log.Println(err) + } - validator := Validator{file: parsed} + return + } + + log.Printf("Parsed:\n%+#v\n\n", parsed) + parsedFiles = append(parsedFiles, parsed) + } + + validator := Validator{files: parsedFiles} errors := validator.validate() if len(errors) != 0 { - for _, err = range errors { + for _, err := range errors { if c, ok := err.(CompilerError); ok { - printCompilerError(file, source, c) + printCompilerError(fileSources, c) } else { log.Println(err) } @@ -107,12 +156,12 @@ func main() { } } - log.Printf("Validated:\n%+#v\n\n", parsed) + // log.Printf("Validated:\n%+#v\n\n", parsedFiles) - wat, err := backendWAT(*parsed) + wat, err := backendWAT(parsedFiles) if err != nil { if c, ok := err.(CompilerError); ok { - printCompilerError(file, source, c) + printCompilerError(fileSources, c) } else { log.Println(err) } diff --git a/parser.go b/parser.go index a245302..59b4d28 100644 --- a/parser.go +++ b/parser.go @@ -17,7 +17,7 @@ const ( type Type struct { Type TypeType Value any - Position uint64 + Position TokenPosition } type NamedType struct { @@ -45,7 +45,7 @@ const ( type Statement struct { Type StatementType Value any - Position uint64 + Position TokenPosition } type ExpressionStatement struct { @@ -88,7 +88,7 @@ type Expression struct { Type ExpressionType Value any ValueType *Type - Position uint64 + Position TokenPosition } type AssignmentExpression struct { @@ -161,6 +161,7 @@ type ParsedParameter struct { type ParsedFunction struct { Name string + FullName string // The fully-qualified name of the function, including the module name Parameters []ParsedParameter ReturnType *Type Body *Block @@ -172,13 +173,14 @@ type Import struct { } type ParsedFile struct { + Module string Imports []Import Functions []ParsedFunction } type Parser struct { Tokens []LexToken - Position uint64 + Position TokenPosition } func (p Parser) copy() Parser { @@ -481,6 +483,8 @@ func (p *Parser) tryUnaryExpression() (*Expression, error) { if token.Type == Type_Identifier { pCopy.nextToken() + // TODO: possible module name + next, err := pCopy.trySeparator(Separator_OpenParen) if err != nil { return nil, err @@ -921,8 +925,9 @@ func (p *Parser) expectFunction() (*ParsedFunction, error) { func (p *Parser) parseFile() (*ParsedFile, error) { var err error - var functions []ParsedFunction + var module string var imports []Import + var functions []ParsedFunction for { token := p.peekToken() @@ -941,6 +946,24 @@ func (p *Parser) parseFile() (*ParsedFile, error) { continue } + if token.Type == Type_Keyword && token.Value.(Keyword) == Keyword_Module { + p.nextToken() + + if module != "" { + return nil, p.error("duplicate module declaration") + } + + module, err = p.expectIdentifier() + if err != nil { + return nil, err + } + + _, err := p.expectSeparator(Separator_Semicolon) + if err != nil { + return nil, err + } + } + var parsedFunction *ParsedFunction parsedFunction, err = p.expectFunction() if err != nil { @@ -950,5 +973,5 @@ func (p *Parser) parseFile() (*ParsedFile, error) { functions = append(functions, *parsedFunction) } - return &ParsedFile{Imports: imports, Functions: functions}, nil + return &ParsedFile{Module: module, Imports: imports, Functions: functions}, nil } diff --git a/types.go b/types.go index c83d5f8..484b9fa 100644 --- a/types.go +++ b/types.go @@ -43,12 +43,12 @@ var STRING_TYPE = Type{Type: Type_Named, Value: STRING_TYPE_NAME} const InvalidValue = 0xEEEEEE // Magic value type CompilerError struct { - Position uint64 + Position TokenPosition Message string } func (e CompilerError) Error() string { - return e.Message + " (at " + strconv.FormatUint(e.Position, 10) + ")" + return e.Message + " (at " + e.Position.SourceFile + ", index " + strconv.FormatUint(e.Position.Position, 10) + ")" } func isSignedInt(primitiveType PrimitiveType) bool { diff --git a/validator.go b/validator.go index 75873ea..8e8e041 100644 --- a/validator.go +++ b/validator.go @@ -6,7 +6,8 @@ import ( ) type Validator struct { - file *ParsedFile + files []*ParsedFile + allFunctions map[string]*ParsedFunction currentBlock *Block currentFunction *ParsedFunction @@ -74,7 +75,7 @@ func isPrimitiveTypeExpandableTo(from PrimitiveType, to PrimitiveType) bool { return false } -func (v *Validator) createError(message string, position uint64) error { +func (v *Validator) createError(message string, position TokenPosition) error { return CompilerError{Position: position, Message: message} } @@ -234,15 +235,8 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error case Expression_FunctionCall: fc := expr.Value.(FunctionCallExpression) - var calledFunc *ParsedFunction = nil - for _, f := range v.file.Functions { - if f.Name == fc.Function { - calledFunc = &f - break - } - } - - if calledFunc == nil { + calledFunc, ok := v.allFunctions[fc.Function] + if !ok { errors = append(errors, v.createError("call to undefined function '"+fc.Function+"'", expr.Position)) return errors } @@ -435,12 +429,34 @@ func (v *Validator) validateFunction(function *ParsedFunction) []error { func (v *Validator) validate() []error { var errors []error - for i := range v.file.Imports { - errors = append(errors, v.validateImport(&v.file.Imports[i])...) + v.allFunctions = make(map[string]*ParsedFunction) + for _, file := range v.files { + for i := range file.Functions { + function := &file.Functions[i] + + fullFunctionName := function.Name + if file.Module != "" { + fullFunctionName = file.Module + "." + fullFunctionName + } + + function.FullName = fullFunctionName + + if _, exists := v.allFunctions[fullFunctionName]; exists { + errors = append(errors, v.createError("duplicate function "+fullFunctionName, function.ReturnType.Position)) + } + + v.allFunctions[fullFunctionName] = function + } } - for i := range v.file.Functions { - errors = append(errors, v.validateFunction(&v.file.Functions[i])...) + for _, file := range v.files { + for i := range file.Imports { + errors = append(errors, v.validateImport(&file.Imports[i])...) + } + + for i := range file.Functions { + errors = append(errors, v.validateFunction(&file.Functions[i])...) + } } return errors