Update language name, Improve implicit casts, Add array/raw memory expressions (WIP)

This commit is contained in:
MrLetsplay 2024-03-30 21:57:38 +01:00
parent 7d38efd106
commit a7007eaf0f
Signed by: mr
SSH Key Fingerprint: SHA256:92jBH80vpXyaZHjaIl47pjRq+Yt7XGTArqQg1V7hSqg
16 changed files with 319 additions and 160 deletions

6
.gitignore vendored
View File

@ -1,3 +1,3 @@
compiler elysium
out.wat a.out
build build

2
README.md Normal file
View File

@ -0,0 +1,2 @@
# Elysium
The Elysium programming language.

View File

@ -7,6 +7,11 @@ import (
"unicode" "unicode"
) )
type Compiler struct {
Files []*ParsedFile
Wasm64 bool
}
func getPrimitiveWATType(primitive PrimitiveType) string { func getPrimitiveWATType(primitive PrimitiveType) string {
switch primitive { switch primitive {
case Primitive_I8, Primitive_I16, Primitive_I32, Primitive_U8, Primitive_U16, Primitive_U32: case Primitive_I8, Primitive_I16, Primitive_I32, Primitive_U8, Primitive_U16, Primitive_U32:
@ -24,17 +29,6 @@ func getPrimitiveWATType(primitive PrimitiveType) string {
panic("unhandled type") panic("unhandled type")
} }
func getWATType(t Type) string {
// TODO: tuples?
if t.Type != Type_Primitive {
panic("not implemented") // TODO: non-primitive types
}
primitive := t.Value.(PrimitiveType)
return getPrimitiveWATType(primitive)
}
func safeASCIIIdentifier(identifier string) string { func safeASCIIIdentifier(identifier string) string {
ascii := "" ascii := ""
for _, rune := range identifier { for _, rune := range identifier {
@ -85,6 +79,23 @@ func pushConstantNumberWAT(primitive PrimitiveType, value any) string {
panic("invalid type") panic("invalid type")
} }
func (c *Compiler) getWATType(t Type) string {
switch t.Type {
case Type_Primitive:
return getPrimitiveWATType(t.Value.(PrimitiveType))
case Type_Named, Type_Array:
if c.Wasm64 {
return "i64"
} else {
return "i32"
}
case Type_Tuple:
panic("tuple type passed to getWATType()")
}
panic("type not implemented in getWATType()")
}
func castPrimitiveWAT(from PrimitiveType, to PrimitiveType) (string, error) { func castPrimitiveWAT(from PrimitiveType, to PrimitiveType) (string, error) {
if from == to { if from == to {
return "", nil return "", nil
@ -153,28 +164,23 @@ func castPrimitiveWAT(from PrimitiveType, to PrimitiveType) (string, error) {
return "i32.wrap_i64\n" + getTypeCast(to), nil return "i32.wrap_i64\n" + getTypeCast(to), nil
} }
func compileExpressionWAT(expr Expression, block *Block) (string, error) { func (c *Compiler) compileExpressionWAT(expr Expression, block *Block) (string, error) {
var err error var err error
switch expr.Type { switch expr.Type {
case Expression_Assignment: case Expression_Assignment:
ass := expr.Value.(AssignmentExpression) ass := expr.Value.(AssignmentExpression)
exprWAT, err := compileExpressionWAT(ass.Value, block) exprWAT, err := c.compileExpressionWAT(ass.Value, block)
if err != nil { if err != nil {
return "", err return "", err
} }
cast := ""
if expr.ValueType.Type == Type_Primitive {
cast = getTypeCast(expr.ValueType.Value.(PrimitiveType))
}
local := strconv.Itoa(block.Locals[ass.Variable].Index) local := strconv.Itoa(block.Locals[ass.Variable].Index)
getLocal := "local.get $" + local + "\n" getLocal := "local.get $" + local + "\n"
setLocal := "local.set $" + local + "\n" setLocal := "local.set $" + local + "\n"
return exprWAT + cast + setLocal + getLocal, nil return exprWAT + setLocal + getLocal, nil
case Expression_Literal: case Expression_Literal:
lit := expr.Value.(LiteralExpression) lit := expr.Value.(LiteralExpression)
switch lit.Literal.Type { switch lit.Literal.Type {
@ -194,6 +200,7 @@ func compileExpressionWAT(expr Expression, block *Block) (string, error) {
cast := "" cast := ""
if expr.ValueType.Type == Type_Primitive { if expr.ValueType.Type == Type_Primitive {
// TODO: technically only needed for function parameters because functions can be called from outside WASM so they might not be fully type checked
cast = getTypeCast(expr.ValueType.Value.(PrimitiveType)) cast = getTypeCast(expr.ValueType.Value.(PrimitiveType))
} }
@ -202,26 +209,15 @@ func compileExpressionWAT(expr Expression, block *Block) (string, error) {
binary := expr.Value.(BinaryExpression) binary := expr.Value.(BinaryExpression)
// TODO: currently fine, only allowed for primitive types, but should be expanded to allow e.g. strings // TODO: currently fine, only allowed for primitive types, but should be expanded to allow e.g. strings
resultType := binary.ResultType.Value.(PrimitiveType) operandType := binary.Left.ValueType.Value.(PrimitiveType)
exprType := expr.ValueType.Value.(PrimitiveType) exprType := expr.ValueType.Value.(PrimitiveType)
watLeft, err := compileExpressionWAT(binary.Left, block) watLeft, err := c.compileExpressionWAT(binary.Left, block)
if err != nil { if err != nil {
return "", err return "", err
} }
// TODO: cast produces unnecessary/wrong cast, make sure to upcast to target type watRight, err := c.compileExpressionWAT(binary.Right, block)
castLeft, err := castPrimitiveWAT(binary.Left.ValueType.Value.(PrimitiveType), resultType)
if err != nil {
return "", err
}
watRight, err := compileExpressionWAT(binary.Right, block)
if err != nil {
return "", err
}
castRight, err := castPrimitiveWAT(binary.Right.ValueType.Value.(PrimitiveType), resultType)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -229,7 +225,7 @@ func compileExpressionWAT(expr Expression, block *Block) (string, error) {
op := "" op := ""
suffix := "" suffix := ""
if isUnsignedInt(resultType) { if isUnsignedInt(operandType) {
suffix = "u" suffix = "u"
} else { } else {
suffix = "s" suffix = "s"
@ -237,38 +233,38 @@ func compileExpressionWAT(expr Expression, block *Block) (string, error) {
switch binary.Operation { switch binary.Operation {
case Operation_Add: case Operation_Add:
op = getPrimitiveWATType(resultType) + ".add\n" op = getPrimitiveWATType(operandType) + ".add\n"
case Operation_Sub: case Operation_Sub:
op = getPrimitiveWATType(resultType) + ".sub\n" op = getPrimitiveWATType(operandType) + ".sub\n"
case Operation_Mul: case Operation_Mul:
op = getPrimitiveWATType(resultType) + ".mul\n" op = getPrimitiveWATType(operandType) + ".mul\n"
case Operation_Div: case Operation_Div:
op = getPrimitiveWATType(resultType) + ".div_" + suffix + "\n" op = getPrimitiveWATType(operandType) + ".div_" + suffix + "\n"
case Operation_Mod: case Operation_Mod:
op = getPrimitiveWATType(resultType) + ".rem_" + suffix + "\n" op = getPrimitiveWATType(operandType) + ".rem_" + suffix + "\n"
case Operation_Greater: case Operation_Greater:
op = getPrimitiveWATType(resultType) + ".gt_" + suffix + "\n" op = getPrimitiveWATType(operandType) + ".gt_" + suffix + "\n"
case Operation_Less: case Operation_Less:
op = getPrimitiveWATType(resultType) + ".lt_" + suffix + "\n" op = getPrimitiveWATType(operandType) + ".lt_" + suffix + "\n"
case Operation_GreaterEquals: case Operation_GreaterEquals:
op = getPrimitiveWATType(resultType) + ".ge_" + suffix + "\n" op = getPrimitiveWATType(operandType) + ".ge_" + suffix + "\n"
case Operation_LessEquals: case Operation_LessEquals:
op = getPrimitiveWATType(resultType) + ".le_" + suffix + "\n" op = getPrimitiveWATType(operandType) + ".le_" + suffix + "\n"
case Operation_NotEquals: case Operation_NotEquals:
op = getPrimitiveWATType(resultType) + ".ne\n" op = getPrimitiveWATType(operandType) + ".ne\n"
case Operation_Equals: case Operation_Equals:
op = getPrimitiveWATType(resultType) + ".eq\n" op = getPrimitiveWATType(operandType) + ".eq\n"
default: default:
panic("operation not implemented") panic("operation not implemented")
} }
return watLeft + castLeft + watRight + castRight + op + getTypeCast(exprType), nil return watLeft + watRight + op + getTypeCast(exprType), nil
case Expression_Tuple: case Expression_Tuple:
tuple := expr.Value.(TupleExpression) tuple := expr.Value.(TupleExpression)
wat := "" wat := ""
for _, member := range tuple.Members { for _, member := range tuple.Members {
memberWAT, err := compileExpressionWAT(member, block) memberWAT, err := c.compileExpressionWAT(member, block)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -282,7 +278,7 @@ func compileExpressionWAT(expr Expression, block *Block) (string, error) {
wat := "" wat := ""
if fc.Parameters != nil { if fc.Parameters != nil {
wat, err = compileExpressionWAT(*fc.Parameters, block) wat, err = c.compileExpressionWAT(*fc.Parameters, block)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -293,7 +289,7 @@ func compileExpressionWAT(expr Expression, block *Block) (string, error) {
neg := expr.Value.(NegateExpression) neg := expr.Value.(NegateExpression)
exprType := expr.ValueType.Value.(PrimitiveType) exprType := expr.ValueType.Value.(PrimitiveType)
wat, err := compileExpressionWAT(neg.Value, block) wat, err := c.compileExpressionWAT(neg.Value, block)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -306,29 +302,50 @@ func compileExpressionWAT(expr Expression, block *Block) (string, error) {
if isFloatingPoint(exprType) { if isFloatingPoint(exprType) {
return watType + ".neg\n", nil return watType + ".neg\n", nil
} }
case Expression_Cast:
cast := expr.Value.(CastExpression)
wat, err := c.compileExpressionWAT(cast.Value, block)
if err != nil {
return "", err
}
// TODO: fine, as it is currently only allowed for primitive types
fromType := cast.Value.ValueType.Value.(PrimitiveType)
toType := cast.Type.Value.(PrimitiveType)
castWAT, err := castPrimitiveWAT(fromType, toType)
if err != nil {
return "", err
}
return wat + castWAT, nil
} }
panic("expr not implemented") panic("expr not implemented")
} }
func compileStatementWAT(stmt Statement, block *Block) (string, error) { func (c *Compiler) compileStatementWAT(stmt Statement, block *Block) (string, error) {
switch stmt.Type { switch stmt.Type {
case Statement_Expression: case Statement_Expression:
expr := stmt.Value.(ExpressionStatement) expr := stmt.Value.(ExpressionStatement)
wat, err := compileExpressionWAT(expr.Expression, block) wat, err := c.compileExpressionWAT(expr.Expression, block)
if err != nil { if err != nil {
return "", err return "", err
} }
numItems := 1 numItems := 0
if expr.Expression.ValueType.Type == Type_Tuple { if expr.Expression.ValueType != nil {
numItems = len(expr.Expression.ValueType.Value.(TupleType).Types) numItems = 1
if expr.Expression.ValueType.Type == Type_Tuple {
numItems = len(expr.Expression.ValueType.Value.(TupleType).Types)
}
} }
return wat + strings.Repeat("drop\n", numItems), nil return wat + strings.Repeat("drop\n", numItems), nil
case Statement_Block: case Statement_Block:
block := stmt.Value.(BlockStatement) block := stmt.Value.(BlockStatement)
wat, err := compileBlockWAT(block.Block) wat, err := c.compileBlockWAT(block.Block)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -336,7 +353,7 @@ func compileStatementWAT(stmt Statement, block *Block) (string, error) {
return wat, nil return wat, nil
case Statement_Return: case Statement_Return:
ret := stmt.Value.(ReturnStatement) ret := stmt.Value.(ReturnStatement)
wat, err := compileExpressionWAT(*ret.Value, block) wat, err := c.compileExpressionWAT(*ret.Value, block)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -349,30 +366,21 @@ func compileStatementWAT(stmt Statement, block *Block) (string, error) {
return "", nil return "", nil
} }
wat, err := compileExpressionWAT(*dlv.Initializer, block) wat, err := c.compileExpressionWAT(*dlv.Initializer, block)
if err != nil { if err != nil {
return "", err return "", err
} }
if dlv.VariableType.Type == Type_Primitive {
castWAT, err := castPrimitiveWAT(dlv.Initializer.ValueType.Value.(PrimitiveType), dlv.VariableType.Value.(PrimitiveType))
if err != nil {
return "", err
}
wat += castWAT
}
return wat + "local.set $" + strconv.Itoa(block.Locals[dlv.Variable].Index) + "\n", nil return wat + "local.set $" + strconv.Itoa(block.Locals[dlv.Variable].Index) + "\n", nil
case Statement_If: case Statement_If:
ifS := stmt.Value.(IfStatement) ifS := stmt.Value.(IfStatement)
conditionWAT, err := compileExpressionWAT(ifS.Condition, block) conditionWAT, err := c.compileExpressionWAT(ifS.Condition, block)
if err != nil { if err != nil {
return "", err return "", err
} }
condBlockWAT, err := compileBlockWAT(ifS.ConditionalBlock) condBlockWAT, err := c.compileBlockWAT(ifS.ConditionalBlock)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -401,7 +409,7 @@ func compileStatementWAT(stmt Statement, block *Block) (string, error) {
if ifS.ElseBlock != nil { if ifS.ElseBlock != nil {
// condition is false // condition is false
elseWAT, err := compileBlockWAT(ifS.ElseBlock) elseWAT, err := c.compileBlockWAT(ifS.ElseBlock)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -416,11 +424,11 @@ func compileStatementWAT(stmt Statement, block *Block) (string, error) {
panic("stmt not implemented") panic("stmt not implemented")
} }
func compileBlockWAT(block *Block) (string, error) { func (c *Compiler) compileBlockWAT(block *Block) (string, error) {
blockWAT := "" blockWAT := ""
for _, stmt := range block.Statements { for _, stmt := range block.Statements {
wat, err := compileStatementWAT(stmt, block) wat, err := c.compileStatementWAT(stmt, block)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -431,17 +439,16 @@ func compileBlockWAT(block *Block) (string, error) {
return blockWAT, nil return blockWAT, nil
} }
func compileFunctionWAT(function ParsedFunction) (string, error) { func (c *Compiler) compileFunctionWAT(function ParsedFunction) (string, error) {
funcWAT := "(func $" + safeASCIIIdentifier(function.FullName) + "\n" funcWAT := "(func $" + safeASCIIIdentifier(function.FullName) + " (export \"" + function.FullName + "\")\n"
for _, local := range function.Locals { for _, local := range function.Locals {
if !local.IsParameter { if !local.IsParameter {
continue continue
} }
funcWAT += "\t(param $" + strconv.Itoa(local.Index) + " " + getWATType(local.Type) + ")\n" funcWAT += "\t(param $" + strconv.Itoa(local.Index) + " " + c.getWATType(local.Type) + ")\n"
} }
// TODO: tuples
returnTypes := []Type{} returnTypes := []Type{}
if function.ReturnType != nil { if function.ReturnType != nil {
returnTypes = []Type{*function.ReturnType} returnTypes = []Type{*function.ReturnType}
@ -451,32 +458,32 @@ func compileFunctionWAT(function ParsedFunction) (string, error) {
} }
for _, t := range returnTypes { for _, t := range returnTypes {
funcWAT += "\t(result " + getWATType(t) + ")\n" funcWAT += "\t(result " + c.getWATType(t) + ")\n"
} }
for _, local := range function.Locals { for _, local := range function.Locals {
if local.IsParameter { if local.IsParameter {
continue continue
} }
funcWAT += "\t(local $" + strconv.Itoa(local.Index) + " " + getWATType(local.Type) + ")\n" funcWAT += "\t(local $" + strconv.Itoa(local.Index) + " " + c.getWATType(local.Type) + ")\n"
} }
wat, err := compileBlockWAT(function.Body) wat, err := c.compileBlockWAT(function.Body)
if err != nil { if err != nil {
return "", err return "", err
} }
funcWAT += wat funcWAT += wat
return funcWAT + ") (export \"" + function.FullName + "\" (func $" + safeASCIIIdentifier(function.FullName) + "))\n", nil return funcWAT + ")\n", nil
} }
func backendWAT(files []*ParsedFile) (string, error) { func (c *Compiler) compile() (string, error) {
module := "(module (memory 1)\n" module := "(module\n"
for _, file := range files { for _, file := range c.Files {
for _, function := range file.Functions { for _, function := range file.Functions {
wat, err := compileFunctionWAT(function) wat, err := c.compileFunctionWAT(function)
if err != nil { if err != nil {
return "", err return "", err
} }

8
example/test.ely Normal file
View File

@ -0,0 +1,8 @@
void b(u64 i) {
}
(u8, u16, u64) a() {
b(1u8);
return 1u8, 2u8, 3u8;
}

View File

@ -1,5 +0,0 @@
module sus;
(u8, u8) a() {
return 1u8, 2u8;
}

2
go.mod
View File

@ -1,3 +1,3 @@
module cringe-studios.com/compiler module git.cringe-studios.com/mr/elysium
go 1.21.7 go 1.21.7

View File

@ -22,7 +22,7 @@ const (
type Keyword uint32 type Keyword uint32
var Keywords []string = []string{"import", "module", "void", "return", "true", "false", "if", "else"} var Keywords []string = []string{"import", "module", "void", "return", "true", "false", "if", "else", "raw"}
const ( const (
Keyword_Import Keyword = iota Keyword_Import Keyword = iota
@ -33,6 +33,7 @@ const (
KeyWord_False KeyWord_False
Keyword_If Keyword_If
Keyword_Else Keyword_Else
Keyword_Raw
) )
type Separator uint32 type Separator uint32

15
main.go
View File

@ -88,6 +88,7 @@ func readEmbedDir(name string, files map[string]string) {
func main() { func main() {
outputFile := flag.String("o", "a.out", "Output file") outputFile := flag.String("o", "a.out", "Output file")
generateWAT := flag.Bool("wat", false, "Generate WAT instead of WASM") generateWAT := flag.Bool("wat", false, "Generate WAT instead of WASM")
wasm64 := flag.Bool("wasm64", false, "Use 64-bit memory (may not be supported in all browsers)")
includeStdlib := flag.Bool("stdlib", true, "Include the standard library") includeStdlib := flag.Bool("stdlib", true, "Include the standard library")
flag.Parse() flag.Parse()
@ -150,7 +151,7 @@ func main() {
parsedFiles = append(parsedFiles, parsed) parsedFiles = append(parsedFiles, parsed)
} }
validator := Validator{files: parsedFiles} validator := Validator{Files: parsedFiles}
errors := validator.validate() errors := validator.validate()
if len(errors) != 0 { if len(errors) != 0 {
for _, err := range errors { for _, err := range errors {
@ -168,7 +169,8 @@ func main() {
// log.Printf("Validated:\n%+#v\n\n", parsedFiles) // log.Printf("Validated:\n%+#v\n\n", parsedFiles)
wat, err := backendWAT(parsedFiles) compiler := Compiler{Files: parsedFiles, Wasm64: *wasm64}
wat, err := compiler.compile()
if err != nil { if err != nil {
if c, ok := err.(CompilerError); ok { if c, ok := err.(CompilerError); ok {
printCompilerError(fileSources, c) printCompilerError(fileSources, c)
@ -193,13 +195,8 @@ func main() {
cmd.Stdin = &input cmd.Stdin = &input
err = cmd.Start() output, err := cmd.CombinedOutput()
if err != nil { if err != nil {
log.Fatalln(err) log.Fatalln(err, string(output))
}
err = cmd.Wait()
if err != nil {
log.Fatalln(err)
} }
} }

View File

@ -82,6 +82,9 @@ const (
Expression_Tuple Expression_Tuple
Expression_FunctionCall Expression_FunctionCall
Expression_Negate Expression_Negate
Expression_ArrayAccess
Expression_RawMemoryReference
Expression_Cast
) )
type Expression struct { type Expression struct {
@ -122,10 +125,9 @@ const (
) )
type BinaryExpression struct { type BinaryExpression struct {
Operation Operation Operation Operation
Left Expression Left Expression
Right Expression Right Expression
ResultType *Type // Type to expand the operands to before performing the operation
} }
type TupleExpression struct { type TupleExpression struct {
@ -141,6 +143,21 @@ type NegateExpression struct {
Value Expression Value Expression
} }
type ArrayAccessExpression struct {
Array Expression
Index Expression
}
type RawMemoryReferenceExpression struct {
Type Type
Address Expression
}
type CastExpression struct {
Type Type
Value Expression
}
type Local struct { type Local struct {
Name string Name string
Type Type Type Type
@ -293,8 +310,6 @@ func (p *Parser) tryType() (*Type, error) {
} }
if tok.Type == Type_Identifier { if tok.Type == Type_Identifier {
// TODO: array type
var theType Type var theType Type
index := slices.Index(PRIMITIVE_TYPE_NAMES, tok.Value.(string)) index := slices.Index(PRIMITIVE_TYPE_NAMES, tok.Value.(string))
@ -314,11 +329,15 @@ func (p *Parser) tryType() (*Type, error) {
break break
} }
_, err = pCopy.expectSeparator(Separator_CloseSquare) sep, err = pCopy.trySeparator(Separator_CloseSquare)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if sep == nil {
return nil, nil
}
theType = Type{Type: Type_Array, Value: ArrayType{ElementType: theType}, Position: theType.Position} theType = Type{Type: Type_Array, Value: ArrayType{ElementType: theType}, Position: theType.Position}
} }
@ -421,7 +440,7 @@ func (p *Parser) tryParanthesizedExpression() (*Expression, error) {
return expr, nil return expr, nil
} }
func (p *Parser) tryUnaryExpression() (*Expression, error) { func (p *Parser) tryUnaryExpressionNoArrayAccess() (*Expression, error) {
pCopy := p.copy() pCopy := p.copy()
token := pCopy.peekToken() token := pCopy.peekToken()
@ -457,6 +476,28 @@ func (p *Parser) tryUnaryExpression() (*Expression, error) {
*p = pCopy *p = pCopy
return &Expression{Type: Expression_Literal, Value: LiteralExpression{Literal: Literal{Type: Literal_Boolean, Primitive: Primitive_Bool, Value: keyword == Keyword_True}}, Position: token.Position}, nil return &Expression{Type: Expression_Literal, Value: LiteralExpression{Literal: Literal{Type: Literal_Boolean, Primitive: Primitive_Bool, Value: keyword == Keyword_True}}, Position: token.Position}, nil
} }
if keyword == Keyword_Raw {
pCopy.nextToken()
rawType, err := pCopy.expectType()
if err != nil {
return nil, err
}
_, err = pCopy.expectSeparator(Separator_Comma)
if err != nil {
return nil, err
}
address, err := pCopy.expectExpression()
if err != nil {
return nil, err
}
*p = pCopy
return &Expression{Type: Expression_RawMemoryReference, Value: RawMemoryReferenceExpression{Type: *rawType, Address: *address}, Position: token.Position}, nil
}
} }
if token.Type == Type_Operator { if token.Type == Type_Operator {
@ -513,6 +554,43 @@ func (p *Parser) tryUnaryExpression() (*Expression, error) {
return nil, nil return nil, nil
} }
func (p *Parser) tryUnaryExpression() (*Expression, error) {
pCopy := p.copy()
expr, err := pCopy.tryUnaryExpressionNoArrayAccess() // TODO: wrong precedence
if err != nil {
return nil, err
}
if expr == nil {
return nil, nil
}
for {
open, err := pCopy.trySeparator(Separator_OpenSquare)
if err != nil {
return nil, err
}
if open == nil {
*p = pCopy
return expr, nil
}
index, err := pCopy.expectExpression()
if err != nil {
return nil, err
}
_, err = pCopy.expectSeparator(Separator_CloseSquare)
if err != nil {
return nil, err
}
expr = &Expression{Type: Expression_ArrayAccess, Value: ArrayAccessExpression{Array: *expr, Index: *index}, Position: expr.Position}
}
}
func (p *Parser) tryBinaryExpression0(opFunc func() (*Expression, error), operators ...Operator) (*Expression, error) { func (p *Parser) tryBinaryExpression0(opFunc func() (*Expression, error), operators ...Operator) (*Expression, error) {
left, err := opFunc() left, err := opFunc()
if err != nil { if err != nil {
@ -578,7 +656,7 @@ func (p *Parser) tryAssignmentExpression() (*Expression, error) {
return nil, nil return nil, nil
} }
if lhs.Type != Expression_VariableReference { // TODO: allow other types if lhs.Type != Expression_VariableReference { // TODO: allow other types (array access)
return p.tryBinaryExpression() return p.tryBinaryExpression()
} }

16
stdlib/alloc.ely Normal file
View File

@ -0,0 +1,16 @@
module alloc;
u64 alloc(u64 size) {
u64 ptr = 0x0u64;
raw(i32, ptr) = 0x03u32;
i32 sus = raw(i32, ptr);
return 0u64;
}
void free(u64 address) {
}
u64 growMemory(u64 numPages) {
return 0u64;
}

View File

@ -1,9 +0,0 @@
module alloc;
u64 alloc(u64 size) {
return 0u64;
}
void free(u64 address) {
}

View File

@ -6,11 +6,43 @@ import (
) )
type Validator struct { type Validator struct {
files []*ParsedFile Files []*ParsedFile
allFunctions map[string]*ParsedFunction AllFunctions map[string]*ParsedFunction
currentBlock *Block CurrentBlock *Block
currentFunction *ParsedFunction CurrentFunction *ParsedFunction
}
func isSameType(a Type, b Type) bool {
if a.Type != b.Type {
return false
}
switch a.Type {
case Type_Primitive:
return a.Value.(PrimitiveType) == b.Value.(PrimitiveType)
case Type_Named:
return a.Value.(NamedType).TypeName == b.Value.(NamedType).TypeName
case Type_Array:
return isSameType(a.Value.(ArrayType).ElementType, b.Value.(ArrayType).ElementType)
case Type_Tuple:
aTuple := a.Value.(TupleType)
bTuple := b.Value.(TupleType)
if len(aTuple.Types) != len(bTuple.Types) {
return false
}
for i := 0; i < len(aTuple.Types); i++ {
if !isSameType(aTuple.Types[i], bTuple.Types[i]) {
return false
}
}
return true
}
panic("type not implemented")
} }
func isTypeExpandableTo(from Type, to Type) bool { func isTypeExpandableTo(from Type, to Type) bool {
@ -40,9 +72,35 @@ func isTypeExpandableTo(from Type, to Type) bool {
return true return true
} }
if from.Type == Type_Array {
return isSameType(from.Value.(ArrayType).ElementType, to.Value.(ArrayType).ElementType)
}
panic("not implemented") panic("not implemented")
} }
func expandExpressionToType(expr *Expression, to Type) {
// TODO: merge with isTypeExpandableTo?
if isSameType(*expr.ValueType, to) {
return
}
if expr.Type == Expression_Tuple {
tupleExpr := expr.Value.(TupleExpression)
tupleType := to.Value.(TupleType)
for i := 0; i < len(tupleType.Types); i++ {
expandExpressionToType(&tupleExpr.Members[i], tupleType.Types[i])
}
expr.Value = tupleExpr
return
}
*expr = Expression{Type: Expression_Cast, Value: CastExpression{Type: to, Value: *expr}, ValueType: &to, Position: expr.Position}
}
func isPrimitiveTypeExpandableTo(from PrimitiveType, to PrimitiveType) bool { func isPrimitiveTypeExpandableTo(from PrimitiveType, to PrimitiveType) bool {
if from == to { if from == to {
return true return true
@ -118,7 +176,7 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error
switch expr.Type { switch expr.Type {
case Expression_Assignment: case Expression_Assignment:
assignment := expr.Value.(AssignmentExpression) assignment := expr.Value.(AssignmentExpression)
local := getLocal(v.currentBlock, assignment.Variable) local := getLocal(v.CurrentBlock, assignment.Variable)
if local == nil { if local == nil {
errors = append(errors, v.createError("Assignment to undeclared variable "+assignment.Variable, expr.Position)) errors = append(errors, v.createError("Assignment to undeclared variable "+assignment.Variable, expr.Position))
return errors return errors
@ -130,8 +188,12 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error
return errors return errors
} }
if !isTypeExpandableTo(*assignment.Value.ValueType, local.Type) { if !isSameType(*assignment.Value.ValueType, local.Type) {
errors = append(errors, v.createError(fmt.Sprintf("cannot assign expression of type %s to variable of type %s", *assignment.Value.ValueType, local.Type), expr.Position)) if !isTypeExpandableTo(*assignment.Value.ValueType, local.Type) {
errors = append(errors, v.createError(fmt.Sprintf("cannot assign expression of type %s to variable of type %s", *assignment.Value.ValueType, local.Type), expr.Position))
}
expandExpressionToType(&assignment.Value, local.Type)
} }
expr.ValueType = &local.Type expr.ValueType = &local.Type
@ -147,7 +209,7 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error
} }
case Expression_VariableReference: case Expression_VariableReference:
reference := expr.Value.(VariableReferenceExpression) reference := expr.Value.(VariableReferenceExpression)
local := getLocal(v.currentBlock, reference.Variable) local := getLocal(v.CurrentBlock, reference.Variable)
if local == nil { if local == nil {
errors = append(errors, v.createError("Reference to undeclared variable "+reference.Variable, expr.Position)) errors = append(errors, v.createError("Reference to undeclared variable "+reference.Variable, expr.Position))
return errors return errors
@ -170,24 +232,19 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error
return errors return errors
} }
leftType := binary.Left.ValueType.Value.(PrimitiveType) var operandType Type
rightType := binary.Right.ValueType.Value.(PrimitiveType) if isTypeExpandableTo(*binary.Left.ValueType, *binary.Right.ValueType) {
operandType = *binary.Right.ValueType
var result PrimitiveType = InvalidValue } else if isTypeExpandableTo(*binary.Right.ValueType, *binary.Left.ValueType) {
if isPrimitiveTypeExpandableTo(leftType, rightType) { operandType = *binary.Left.ValueType
result = leftType } else {
} errors = append(errors, v.createError(fmt.Sprintf("cannot compare the types %s and %s without an explicit cast", binary.Left.ValueType.Value.(PrimitiveType), binary.Right.ValueType.Value.(PrimitiveType)), expr.Position))
if isPrimitiveTypeExpandableTo(rightType, leftType) {
result = leftType
}
if result == InvalidValue {
errors = append(errors, v.createError(fmt.Sprintf("cannot compare the types %s and %s without an explicit cast", leftType, rightType), expr.Position))
return errors return errors
} }
binary.ResultType = &Type{Type: Type_Primitive, Value: result} expandExpressionToType(&binary.Left, operandType)
expandExpressionToType(&binary.Right, operandType)
expr.ValueType = &Type{Type: Type_Primitive, Value: Primitive_Bool} expr.ValueType = &Type{Type: Type_Primitive, Value: Primitive_Bool}
} }
@ -205,7 +262,6 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error
return errors return errors
} }
binary.ResultType = &Type{Type: Type_Primitive, Value: result}
expr.ValueType = &Type{Type: Type_Primitive, Value: result} expr.ValueType = &Type{Type: Type_Primitive, Value: result}
} }
@ -235,7 +291,7 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error
case Expression_FunctionCall: case Expression_FunctionCall:
fc := expr.Value.(FunctionCallExpression) fc := expr.Value.(FunctionCallExpression)
calledFunc, ok := v.allFunctions[fc.Function] calledFunc, ok := v.AllFunctions[fc.Function]
if !ok { if !ok {
errors = append(errors, v.createError("call to undefined function '"+fc.Function+"'", expr.Position)) errors = append(errors, v.createError("call to undefined function '"+fc.Function+"'", expr.Position))
return errors return errors
@ -248,9 +304,11 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error
return errors return errors
} }
params := []Expression{*fc.Parameters} params := []*Expression{fc.Parameters}
if fc.Parameters.Type == Expression_Tuple { if fc.Parameters.Type == Expression_Tuple {
params = fc.Parameters.Value.(TupleExpression).Members for i := 0; i < len(fc.Parameters.Value.(TupleExpression).Members); i++ {
params[i] = &fc.Parameters.Value.(TupleExpression).Members[i]
}
} }
if len(params) != len(calledFunc.Parameters) { if len(params) != len(calledFunc.Parameters) {
@ -263,6 +321,8 @@ func (v *Validator) validatePotentiallyVoidExpression(expr *Expression) []error
if !isTypeExpandableTo(*typeGiven.ValueType, typeExpected.Type) { if !isTypeExpandableTo(*typeGiven.ValueType, typeExpected.Type) {
errors = append(errors, v.createError("invalid type for parameter "+strconv.Itoa(i), expr.Position)) errors = append(errors, v.createError("invalid type for parameter "+strconv.Itoa(i), expr.Position))
} }
expandExpressionToType(typeGiven, typeExpected.Type)
} }
} }
@ -319,7 +379,7 @@ func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local)
case Statement_Return: case Statement_Return:
ret := stmt.Value.(ReturnStatement) ret := stmt.Value.(ReturnStatement)
if ret.Value != nil { if ret.Value != nil {
if v.currentFunction.ReturnType == nil { if v.CurrentFunction.ReturnType == nil {
errors = append(errors, v.createError("cannot return value from void function", stmt.Position)) errors = append(errors, v.createError("cannot return value from void function", stmt.Position))
return errors return errors
} }
@ -330,10 +390,12 @@ func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local)
return errors return errors
} }
if !isTypeExpandableTo(*ret.Value.ValueType, *v.currentFunction.ReturnType) { if !isTypeExpandableTo(*ret.Value.ValueType, *v.CurrentFunction.ReturnType) {
errors = append(errors, v.createError(fmt.Sprintf("cannot return value of type %s from function returning %s", *ret.Value.ValueType, *v.currentFunction.ReturnType), ret.Value.Position)) errors = append(errors, v.createError(fmt.Sprintf("cannot return value of type %s from function returning %s", *ret.Value.ValueType, *v.CurrentFunction.ReturnType), ret.Value.Position))
} }
} else if v.currentFunction.ReturnType != nil {
expandExpressionToType(ret.Value, *v.CurrentFunction.ReturnType)
} else if v.CurrentFunction.ReturnType != nil {
errors = append(errors, v.createError("missing return value", stmt.Position)) errors = append(errors, v.createError("missing return value", stmt.Position))
} }
@ -349,14 +411,16 @@ func (v *Validator) validateStatement(stmt *Statement, functionLocals *[]Local)
if !isTypeExpandableTo(*dlv.Initializer.ValueType, dlv.VariableType) { if !isTypeExpandableTo(*dlv.Initializer.ValueType, dlv.VariableType) {
errors = append(errors, v.createError(fmt.Sprintf("cannot assign expression of type %s to variable of type %s", *dlv.Initializer.ValueType, dlv.VariableType), stmt.Position)) errors = append(errors, v.createError(fmt.Sprintf("cannot assign expression of type %s to variable of type %s", *dlv.Initializer.ValueType, dlv.VariableType), stmt.Position))
} }
expandExpressionToType(dlv.Initializer, dlv.VariableType)
} }
if getLocal(v.currentBlock, dlv.Variable) != nil { if getLocal(v.CurrentBlock, dlv.Variable) != nil {
errors = append(errors, v.createError("redeclaration of variable '"+dlv.Variable+"'", stmt.Position)) errors = append(errors, v.createError("redeclaration of variable '"+dlv.Variable+"'", stmt.Position))
} }
local := Local{Name: dlv.Variable, Type: dlv.VariableType, IsParameter: false, Index: len(*functionLocals)} local := Local{Name: dlv.Variable, Type: dlv.VariableType, IsParameter: false, Index: len(*functionLocals)}
v.currentBlock.Locals[dlv.Variable] = local v.CurrentBlock.Locals[dlv.Variable] = local
*functionLocals = append(*functionLocals, local) *functionLocals = append(*functionLocals, local)
stmt.Value = dlv stmt.Value = dlv
@ -394,7 +458,7 @@ func (v *Validator) validateBlock(block *Block, functionLocals *[]Local) []error
} }
for i := range block.Statements { for i := range block.Statements {
v.currentBlock = block v.CurrentBlock = block
stmt := &block.Statements[i] stmt := &block.Statements[i]
errors = append(errors, v.validateStatement(stmt, functionLocals)...) errors = append(errors, v.validateStatement(stmt, functionLocals)...)
} }
@ -407,7 +471,7 @@ func (v *Validator) validateFunction(function *ParsedFunction) []error {
var locals []Local var locals []Local
v.currentFunction = function v.CurrentFunction = function
body := function.Body body := function.Body
body.Locals = make(map[string]Local) body.Locals = make(map[string]Local)
@ -429,8 +493,8 @@ func (v *Validator) validateFunction(function *ParsedFunction) []error {
func (v *Validator) validate() []error { func (v *Validator) validate() []error {
var errors []error var errors []error
v.allFunctions = make(map[string]*ParsedFunction) v.AllFunctions = make(map[string]*ParsedFunction)
for _, file := range v.files { for _, file := range v.Files {
for i := range file.Functions { for i := range file.Functions {
function := &file.Functions[i] function := &file.Functions[i]
@ -441,15 +505,15 @@ func (v *Validator) validate() []error {
function.FullName = fullFunctionName function.FullName = fullFunctionName
if _, exists := v.allFunctions[fullFunctionName]; exists { if _, exists := v.AllFunctions[fullFunctionName]; exists {
errors = append(errors, v.createError("duplicate function "+fullFunctionName, function.ReturnType.Position)) errors = append(errors, v.createError("duplicate function "+fullFunctionName, function.ReturnType.Position))
} }
v.allFunctions[fullFunctionName] = function v.AllFunctions[fullFunctionName] = function
} }
} }
for _, file := range v.files { for _, file := range v.Files {
for i := range file.Imports { for i := range file.Imports {
errors = append(errors, v.validateImport(&file.Imports[i])...) errors = append(errors, v.validateImport(&file.Imports[i])...)
} }