about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
authorBaitinq <[email protected]>2025-03-20 23:36:53 +0100
committerBaitinq <[email protected]>2025-03-20 23:36:53 +0100
commit78fad60783eae3799707a50d4d663f4f13707cde (patch)
tree2577faee6530a974b08cbd0dbf9f26070e38dce0 /src
parentExamples: Add new function as return example (diff)
downloadinterpreter-78fad60783eae3799707a50d4d663f4f13707cde.tar.gz
interpreter-78fad60783eae3799707a50d4d663f4f13707cde.tar.bz2
interpreter-78fad60783eae3799707a50d4d663f4f13707cde.zip
Codegen: Support functions as return values
Diffstat (limited to 'src')
-rw-r--r--src/codegen.zig73
1 files changed, 62 insertions, 11 deletions
diff --git a/src/codegen.zig b/src/codegen.zig
index b6941db..c53c270 100644
--- a/src/codegen.zig
+++ b/src/codegen.zig
@@ -50,6 +50,13 @@ pub const CodeGen = struct {
             .value = printf_function,
             .type = printf_function_type,
             .stack_level = null,
+            .node = try self.create_node(.{ .FUNCTION_DEFINITION = .{
+                .statements = &[_]*parser.Node{},
+                .parameters = &[_]*parser.Node{},
+                .return_type = try self.create_node(.{ .TYPE = .{ .SIMPLE_TYPE = .{
+                    .name = "i64",
+                } } }),
+            } }),
         }));
 
         try self.create_print_function();
@@ -145,6 +152,7 @@ pub const CodeGen = struct {
                 .value = alloca,
                 .type = llvm.LLVMVoidType(), // This gets set to the correct type during the expression type resolution. ALTERNATIVE: Pass the alloca
                 .stack_level = null,
+                .node = statement,
             }));
         }
 
@@ -152,11 +160,13 @@ pub const CodeGen = struct {
         try self.environment.add_variable(assignment_statement.name, variable);
     }
 
-    fn generate_function_call_statement(self: *CodeGen, statement: *parser.Node) CodeGenError!llvm.LLVMValueRef {
+    fn generate_function_call_statement(self: *CodeGen, statement: *parser.Node) CodeGenError!*Variable {
         errdefer std.debug.print("Error generating function call statement\n", .{});
         std.debug.assert(statement.* == parser.Node.FUNCTION_CALL_STATEMENT);
         const function_call_statement = statement.FUNCTION_CALL_STATEMENT;
 
+        var node = statement;
+
         var function: *Variable = undefined;
         switch (function_call_statement.expression.*) {
             .PRIMARY_EXPRESSION => |primary_expression| {
@@ -164,6 +174,7 @@ pub const CodeGen = struct {
                 function = self.environment.get_variable(primary_expression.IDENTIFIER.name) orelse return CodeGenError.CompilationError;
                 if (llvm.LLVMGetValueKind(function.value) != llvm.LLVMFunctionValueKind) {
                     function.value = llvm.LLVMBuildLoad2(self.builder, llvm.LLVMPointerType(function.type, 0), function.value, "");
+                    node = function.node.?;
                 }
             },
             .FUNCTION_DEFINITION => |*function_definition| {
@@ -179,7 +190,20 @@ pub const CodeGen = struct {
             try arguments.append(arg.value);
         }
 
-        return llvm.LLVMBuildCall2(self.builder, function.type, function.value, @ptrCast(arguments.items), @intCast(arguments.items.len), "") orelse return CodeGenError.CompilationError;
+        const res = llvm.LLVMBuildCall2(self.builder, function.type, function.value, @ptrCast(arguments.items), @intCast(arguments.items.len), "") orelse return CodeGenError.CompilationError;
+
+        const function_return_type = switch (function.node.?.*) {
+            .FUNCTION_DEFINITION => |x| x.return_type,
+            .PRIMARY_EXPRESSION => |x| x.IDENTIFIER.type.?,
+            else => unreachable,
+        };
+
+        return self.create_variable(.{
+            .type = try self.get_llvm_type(function_return_type),
+            .value = res,
+            .stack_level = null,
+            .node = node,
+        }) catch return CodeGenError.CompilationError;
     }
 
     fn generate_return_statement(self: *CodeGen, statement: *parser.Node) !void {
@@ -260,7 +284,10 @@ pub const CodeGen = struct {
                     }
                     try paramtypes.append(param_type);
                 }
-                const return_type = try self.get_llvm_type(function_definition.return_type);
+                var return_type = try self.get_llvm_type(function_definition.return_type);
+                if (function_definition.return_type.TYPE == .FUNCTION_TYPE) {
+                    return_type = llvm.LLVMPointerType(return_type, 0);
+                }
                 const function_type = llvm.LLVMFunctionType(return_type, paramtypes.items.ptr, @intCast(paramtypes.items.len), 0) orelse return CodeGenError.CompilationError;
                 const function = llvm.LLVMAddFunction(self.llvm_module, try std.fmt.allocPrintZ(self.arena, "{s}", .{name orelse "unnamed_func"}), function_type) orelse return CodeGenError.CompilationError;
                 const function_entry = llvm.LLVMAppendBasicBlock(function, "entrypoint") orelse return CodeGenError.CompilationError;
@@ -278,6 +305,7 @@ pub const CodeGen = struct {
                         .value = function,
                         .type = function_type,
                         .stack_level = null,
+                        .node = expression,
                     }));
                 }
 
@@ -303,6 +331,7 @@ pub const CodeGen = struct {
                         .value = alloca,
                         .type = param_type,
                         .stack_level = null,
+                        .node = param_node,
                     }));
                 }
 
@@ -319,27 +348,26 @@ pub const CodeGen = struct {
                         .value = function,
                         .type = function_type,
                         .stack_level = null,
+                        .node = expression,
                     });
                 }
 
                 _ = llvm.LLVMBuildStore(self.builder, function, ptr.?.value) orelse return CodeGenError.CompilationError;
                 ptr.?.type = function_type;
+                ptr.?.node = expression;
+
                 return ptr.?;
             },
             .FUNCTION_CALL_STATEMENT => |*fn_call| {
                 if (name != null) {
                     const ptr = self.environment.get_variable(name.?) orelse unreachable;
                     const result = try self.generate_function_call_statement(@ptrCast(fn_call));
-                    _ = llvm.LLVMBuildStore(self.builder, result, ptr.value) orelse return CodeGenError.CompilationError;
-                    ptr.type = llvm.LLVMInt64Type();
+                    _ = llvm.LLVMBuildStore(self.builder, result.value, ptr.value) orelse return CodeGenError.CompilationError;
+                    ptr.type = result.type;
+                    ptr.node = result.node;
                     return ptr;
                 } else {
-                    const r = try self.generate_function_call_statement(@ptrCast(fn_call));
-                    return try self.create_variable(.{
-                        .value = r,
-                        .type = llvm.LLVMInt64Type(),
-                        .stack_level = null,
-                    });
+                    return try self.generate_function_call_statement(@ptrCast(fn_call));
                 }
             },
             .PRIMARY_EXPRESSION => |primary_expression| switch (primary_expression) {
@@ -440,6 +468,7 @@ pub const CodeGen = struct {
                     .value = llvm.LLVMAddGlobal(self.llvm_module, literal_type, try std.fmt.allocPrintZ(self.arena, "{s}", .{name.?})),
                     .type = literal_type,
                     .stack_level = null,
+                    .node = null, //TODO
                 });
                 llvm.LLVMSetInitializer(ptr.value, literal_val);
                 return ptr;
@@ -454,6 +483,7 @@ pub const CodeGen = struct {
             .value = literal_val,
             .type = literal_type,
             .stack_level = null,
+            .node = null, //TODO
         });
     }
 
@@ -503,6 +533,13 @@ pub const CodeGen = struct {
             .value = print_function,
             .type = print_function_type,
             .stack_level = null,
+            .node = try self.create_node(.{ .FUNCTION_DEFINITION = .{
+                .statements = &[_]*parser.Node{},
+                .parameters = &[_]*parser.Node{},
+                .return_type = try self.create_node(.{ .TYPE = .{ .SIMPLE_TYPE = .{
+                    .name = "i64",
+                } } }),
+            } }),
         }));
     }
 
@@ -532,6 +569,13 @@ pub const CodeGen = struct {
             .value = print_function,
             .type = print_function_type,
             .stack_level = null,
+            .node = try self.create_node(.{ .FUNCTION_DEFINITION = .{
+                .statements = &[_]*parser.Node{},
+                .parameters = &[_]*parser.Node{},
+                .return_type = try self.create_node(.{ .TYPE = .{ .SIMPLE_TYPE = .{
+                    .name = "i64",
+                } } }),
+            } }),
         }));
     }
 
@@ -540,11 +584,18 @@ pub const CodeGen = struct {
         variable.* = variable_value;
         return variable;
     }
+
+    fn create_node(self: *CodeGen, node_value: parser.Node) !*parser.Node {
+        const node = try self.arena.create(parser.Node);
+        node.* = node_value;
+        return node;
+    }
 };
 
 const Variable = struct {
     type: llvm.LLVMTypeRef,
     value: llvm.LLVMValueRef,
+    node: ?*parser.Node,
     stack_level: ?usize,
 };