about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/codegen.zig35
1 files changed, 20 insertions, 15 deletions
diff --git a/src/codegen.zig b/src/codegen.zig
index 619e6db..b3fb24a 100644
--- a/src/codegen.zig
+++ b/src/codegen.zig
@@ -232,21 +232,26 @@ pub const CodeGen = struct {
 
         const res = llvm.LLVMBuildCall2(self.builder, try self.get_llvm_type(function.node_type), function.value, @ptrCast(arguments.items), @intCast(arguments.items.len), "") orelse return CodeGenError.CompilationError;
 
-        var function_return_type: *parser.Node = undefined;
-        switch (function.node.*) {
-            .FUNCTION_DEFINITION => |x| {
-                function_return_type = x.return_type;
-            },
-            .PRIMARY_EXPRESSION => |x| {
-                const f = self.environment.get_variable(x.IDENTIFIER.name).?.node_type;
-                std.debug.assert(f.TYPE == .FUNCTION_TYPE);
-                function_return_type = f.TYPE.FUNCTION_TYPE.return_type;
-            },
-            .TYPE => |x| {
-                function_return_type = x.FUNCTION_TYPE.return_type;
-            },
-            else => unreachable,
-        }
+        const get_function_return_type = struct {
+            fn call(iSelf: *CodeGen, fun: *parser.Node) *parser.Node {
+                switch (fun.*) {
+                    .FUNCTION_DEFINITION => |x| {
+                        return x.return_type;
+                    },
+                    .PRIMARY_EXPRESSION => |x| {
+                        const f = iSelf.environment.get_variable(x.IDENTIFIER.name).?.node_type;
+                        std.debug.assert(f.TYPE == .FUNCTION_TYPE);
+                        return call(iSelf, f);
+                    },
+                    .TYPE => |x| {
+                        return x.FUNCTION_TYPE.return_type;
+                    },
+                    else => unreachable,
+                }
+            }
+        };
+
+        const function_return_type = get_function_return_type.call(self, function.node_type);
 
         std.debug.print("FN: {s} -> ret: {any}\n", .{ function_call_statement.expression.PRIMARY_EXPRESSION.IDENTIFIER.name, function_return_type });