about summary refs log tree commit diff
diff options
context:
space:
mode:
authorBaitinq <[email protected]>2025-02-16 23:24:31 +0100
committerBaitinq <[email protected]>2025-02-16 23:24:31 +0100
commiteb985bc9399d0bc7761273ce933d2f3cf219d320 (patch)
tree6df81a66a7ac61e4c5b4216b33594fe40c4c91c6
parentCodegen: support bool type (diff)
downloadpry-lang-eb985bc9399d0bc7761273ce933d2f3cf219d320.tar.gz
pry-lang-eb985bc9399d0bc7761273ce933d2f3cf219d320.tar.bz2
pry-lang-eb985bc9399d0bc7761273ce933d2f3cf219d320.zip
Parser: Add proper support for type parsing
-rw-r--r--grammar.ebnf8
-rw-r--r--src/codegen.zig30
-rw-r--r--src/parser.zig77
3 files changed, 98 insertions, 17 deletions
diff --git a/grammar.ebnf b/grammar.ebnf
index a169199..ed87d7f 100644
--- a/grammar.ebnf
+++ b/grammar.ebnf
@@ -28,4 +28,10 @@ PrimaryExpression ::= NUMBER | BOOLEAN | IDENTIFIER | FunctionCallStatement | Fu
 
 FunctionDefinition ::= LPAREN FunctionParameters? RPAREN ARROW IDENTIFIER LBRACE Statement* ReturnStatement SEMICOLON RBRACE
 
-FunctionParameters ::= IDENTIFIER ":" IDENTIFIER ("," IDENTIFIER ":" IDENTIFIER)*
+FunctionParameters ::= IDENTIFIER ":" Type ("," IDENTIFIER ":" Type)*
+
+Type ::= IDENTIFIER | FunctionType
+
+FunctionType ::= LPAREN (Type ("," Type)*)? RPAREN ARROW Type
+
+ParameterTypes ::= Type ("," Type)*
diff --git a/src/codegen.zig b/src/codegen.zig
index f096e8e..66271a1 100644
--- a/src/codegen.zig
+++ b/src/codegen.zig
@@ -245,9 +245,9 @@ pub const CodeGen = struct {
                 var paramtypes = std.ArrayList(types.LLVMTypeRef).init(self.arena);
                 for (function_definition.parameters) |param| {
                     std.debug.assert(param.PRIMARY_EXPRESSION == .IDENTIFIER);
-                    try paramtypes.append(get_llvm_type(param.PRIMARY_EXPRESSION.IDENTIFIER.type.?));
+                    try paramtypes.append(try self.get_llvm_type(param.PRIMARY_EXPRESSION.IDENTIFIER.type.?));
                 }
-                const return_type = get_llvm_type(function_definition.return_type);
+                const return_type = try self.get_llvm_type(function_definition.return_type);
                 const function_type = core.LLVMFunctionType(return_type, paramtypes.items.ptr, @intCast(paramtypes.items.len), 0) orelse return CodeGenError.CompilationError;
                 const function = core.LLVMAddFunction(self.llvm_module, try std.fmt.allocPrintZ(self.arena, "{s}", .{name orelse "unnamed_func"}), function_type) orelse return CodeGenError.CompilationError;
                 const function_entry = core.LLVMAppendBasicBlock(function, "entrypoint") orelse return CodeGenError.CompilationError;
@@ -270,7 +270,7 @@ pub const CodeGen = struct {
                     const param_node = function_definition.parameters[parameters_index];
                     std.debug.assert(param_node.* == .PRIMARY_EXPRESSION);
 
-                    const param_type = get_llvm_type(param_node.PRIMARY_EXPRESSION.IDENTIFIER.type.?);
+                    const param_type = try self.get_llvm_type(param_node.PRIMARY_EXPRESSION.IDENTIFIER.type.?);
                     // We need to alloca params because we assume all identifiers are alloca TODO:: Is this correct
                     const alloca = core.LLVMBuildAlloca(self.builder, param_type, try std.fmt.allocPrintZ(self.arena, "{s}", .{param_node.PRIMARY_EXPRESSION.IDENTIFIER.name}));
                     _ = core.LLVMBuildStore(self.builder, p, alloca);
@@ -458,10 +458,26 @@ pub const CodeGen = struct {
         });
     }
 
-    fn get_llvm_type(type_name: []const u8) types.LLVMTypeRef {
-        if (std.mem.eql(u8, type_name, "i64")) return core.LLVMInt64Type();
-        if (std.mem.eql(u8, type_name, "bool")) return core.LLVMInt1Type();
-        unreachable;
+    fn get_llvm_type(self: *CodeGen, node: *parser.Node) !types.LLVMTypeRef {
+        std.debug.assert(node.* == parser.Node.TYPE);
+        const type_node = node.TYPE;
+
+        switch (type_node) {
+            .SIMPLE_TYPE => |t| {
+                if (std.mem.eql(u8, t.name, "i64")) return core.LLVMInt64Type();
+                if (std.mem.eql(u8, t.name, "bool")) return core.LLVMInt1Type();
+                unreachable;
+            },
+            // TODO: Properly handle this vv
+            .FUNCTION_TYPE => |t| {
+                const return_type = try self.get_llvm_type(t.return_type);
+                var paramtypes = std.ArrayList(types.LLVMTypeRef).init(self.arena);
+                for (t.parameters) |param| {
+                    try paramtypes.append(try self.get_llvm_type(param));
+                }
+                return core.LLVMFunctionType(return_type, paramtypes.items.ptr, @intCast(paramtypes.items.len), 0) orelse unreachable;
+            },
+        }
     }
 
     fn create_print_function(self: *CodeGen) !void {
diff --git a/src/parser.zig b/src/parser.zig
index 2486662..69f58be 100644
--- a/src/parser.zig
+++ b/src/parser.zig
@@ -57,13 +57,22 @@ pub const Node = union(enum) {
         },
         IDENTIFIER: struct {
             name: []const u8,
-            type: ?[]const u8,
+            type: ?*Node,
         },
     },
     FUNCTION_DEFINITION: struct {
         statements: []*Node,
         parameters: []*Node,
-        return_type: []const u8,
+        return_type: *Node,
+    },
+    TYPE: union(enum) {
+        SIMPLE_TYPE: struct {
+            name: []const u8,
+        },
+        FUNCTION_TYPE: struct {
+            parameters: []*Node,
+            return_type: *Node,
+        },
     },
     RETURN_STATEMENT: struct {
         expression: *Node,
@@ -392,8 +401,7 @@ pub const Parser = struct {
 
         _ = try self.parse_token(tokenizer.TokenType.ARROW);
 
-        const type_expr = try self.parse_primary_expression();
-        if (type_expr.PRIMARY_EXPRESSION != .IDENTIFIER) return ParserError.ParsingError;
+        const return_type = try self.parse_type();
 
         _ = try self.parse_token(tokenizer.TokenType.LBRACE);
 
@@ -409,13 +417,13 @@ pub const Parser = struct {
         return self.create_node(.{ .FUNCTION_DEFINITION = .{
             .statements = nodes.items,
             .parameters = parameters,
-            .return_type = type_expr.PRIMARY_EXPRESSION.IDENTIFIER.name,
+            .return_type = return_type,
         } });
     }
 
-    // FunctionParameters ::= IDENTIFIER ":" IDENTIFIER ("," IDENTIFIER ":" IDENTIFIER)*
+    // FunctionParameters ::= IDENTIFIER ":" Type ("," IDENTIFIER ":" Type)*
     fn parse_function_parameters(self: *Parser) ParserError![]*Node {
-        errdefer if (!self.try_context) std.debug.print("Error parsing function parameters {any}\n", .{self.peek_token()});
+        errdefer std.debug.print("Error parsing function parameters {any}\n", .{self.peek_token()});
 
         var node_list = std.ArrayList(*Node).init(self.arena);
 
@@ -426,14 +434,15 @@ pub const Parser = struct {
             }
             first = false;
             const ident = self.accept_token(tokenizer.TokenType.IDENTIFIER) orelse return node_list.items;
+
             _ = try self.parse_token(tokenizer.TokenType.COLON);
-            const type_ident = try self.parse_token(tokenizer.TokenType.IDENTIFIER);
+            const type_annotation = try self.parse_type();
 
             try node_list.append(try self.create_node(.{
                 .PRIMARY_EXPRESSION = .{
                     .IDENTIFIER = .{
                         .name = try self.arena.dupe(u8, ident.type.IDENTIFIER),
-                        .type = try self.arena.dupe(u8, type_ident.type.IDENTIFIER),
+                        .type = type_annotation,
                     },
                 },
             }));
@@ -457,6 +466,56 @@ pub const Parser = struct {
         });
     }
 
+    // Type ::= IDENTIFIER | FunctionType
+    fn parse_type(self: *Parser) ParserError!*Node {
+        errdefer if (!self.try_context) std.debug.print("Error parsing type annotation {any}\n", .{self.peek_token()});
+
+        return self.accept_parse(parse_function_type) orelse switch (self.consume_token().?.type) {
+            .IDENTIFIER => |ident| {
+                //TODO: we should only accept specific type identifiers
+                return try self.create_node(.{
+                    .TYPE = .{
+                        .SIMPLE_TYPE = .{
+                            .name = try self.arena.dupe(u8, ident),
+                        },
+                    },
+                });
+            },
+            else => ParserError.ParsingError,
+        };
+    }
+
+    // FunctionType ::= LPAREN (Type ("," Type)*)? RPAREN ARROW Type
+    fn parse_function_type(self: *Parser) ParserError!*Node {
+        errdefer if (!self.try_context) std.debug.print("Error parsing function type {any}\n", .{self.peek_token()});
+
+        _ = try self.parse_token(tokenizer.TokenType.LPAREN);
+
+        var parameters = std.ArrayList(*Node).init(self.arena);
+        var first = true;
+        while (self.accept_parse(parse_type)) |type_annotation| {
+            if (!first) {
+                _ = try self.parse_token(tokenizer.TokenType.COMMA);
+            }
+            try parameters.append(type_annotation);
+            first = false;
+        }
+        _ = try self.parse_token(tokenizer.TokenType.RPAREN);
+
+        _ = try self.parse_token(tokenizer.TokenType.ARROW);
+
+        const return_type = try self.parse_type();
+
+        return try self.create_node(.{
+            .TYPE = .{
+                .FUNCTION_TYPE = .{
+                    .parameters = parameters.items,
+                    .return_type = return_type,
+                },
+            },
+        });
+    }
+
     fn parse_token(self: *Parser, expected_token: std.meta.Tag(tokenizer.TokenType)) ParserError!tokenizer.Token {
         errdefer if (!self.try_context) std.debug.print("Error accepting token: {any}\n", .{expected_token});
         const token = self.peek_token() orelse return ParserError.ParsingError;