about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/codegen.zig8
-rw-r--r--src/parser.zig8
2 files changed, 14 insertions, 2 deletions
diff --git a/src/codegen.zig b/src/codegen.zig
index 3884923..ed1890d 100644
--- a/src/codegen.zig
+++ b/src/codegen.zig
@@ -247,7 +247,8 @@ pub const CodeGen = struct {
                     std.debug.assert(param.PRIMARY_EXPRESSION == .IDENTIFIER);
                     try paramtypes.append(core.LLVMInt64Type());
                 }
-                const function_type = core.LLVMFunctionType(core.LLVMInt64Type(), paramtypes.items.ptr, @intCast(paramtypes.items.len), 0) orelse return CodeGenError.CompilationError;
+                const return_type = get_llvm_type(function_definition.return_type);
+                const function_type = core.LLVMFunctionType(return_type, paramtypes.items.ptr, @intCast(paramtypes.items.len), 0) orelse return CodeGenError.CompilationError;
                 const function = core.LLVMAddFunction(self.llvm_module, try std.fmt.allocPrintZ(self.arena, "{s}", .{name orelse "unnamed_func"}), function_type) orelse return CodeGenError.CompilationError;
                 const function_entry = core.LLVMAppendBasicBlock(function, "entrypoint") orelse return CodeGenError.CompilationError;
                 core.LLVMPositionBuilderAtEnd(self.builder, function_entry);
@@ -455,6 +456,11 @@ pub const CodeGen = struct {
         });
     }
 
+    fn get_llvm_type(type_name: []const u8) types.LLVMTypeRef {
+        if (std.mem.eql(u8, type_name, "i64")) return core.LLVMInt64Type();
+        unreachable;
+    }
+
     fn create_print_function(self: *CodeGen) !void {
         const print_function_type = core.LLVMFunctionType(core.LLVMVoidType(), @constCast(&[_]types.LLVMTypeRef{core.LLVMInt64Type()}), 1, 0);
         const print_function = core.LLVMAddFunction(self.llvm_module, "print", print_function_type);
diff --git a/src/parser.zig b/src/parser.zig
index c5eb623..7d71de3 100644
--- a/src/parser.zig
+++ b/src/parser.zig
@@ -62,6 +62,7 @@ pub const Node = union(enum) {
     FUNCTION_DEFINITION: struct {
         statements: []*Node,
         parameters: []*Node,
+        return_type: []const u8,
     },
     RETURN_STATEMENT: struct {
         expression: *Node,
@@ -376,7 +377,7 @@ pub const Parser = struct {
         };
     }
 
-    // FunctionDefinition ::= LPAREN FunctionParamters? RPAREN ARROW LBRACE Statement* ReturnStatement RBRACE
+    // FunctionDefinition ::= LPAREN FunctionParameters? RPAREN ARROW IDENTIFIER LBRACE Statement* ReturnStatement SEMICOLON RBRACE
     fn parse_function_definition(self: *Parser) ParserError!*Node {
         errdefer if (!self.try_context) std.debug.print("Error parsing function definition {any}\n", .{self.peek_token()});
 
@@ -387,6 +388,10 @@ pub const Parser = struct {
         _ = try self.parse_token(tokenizer.TokenType.RPAREN);
 
         _ = try self.parse_token(tokenizer.TokenType.ARROW);
+
+        const type_expr = try self.parse_primary_expression();
+        if (type_expr.PRIMARY_EXPRESSION != .IDENTIFIER) return ParserError.ParsingError;
+
         _ = try self.parse_token(tokenizer.TokenType.LBRACE);
 
         var nodes = std.ArrayList(*Node).init(self.arena);
@@ -401,6 +406,7 @@ pub const Parser = struct {
         return self.create_node(.{ .FUNCTION_DEFINITION = .{
             .statements = nodes.items,
             .parameters = parameters,
+            .return_type = type_expr.PRIMARY_EXPRESSION.IDENTIFIER.name,
         } });
     }