about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/codegen.zig23
1 files changed, 17 insertions, 6 deletions
diff --git a/src/codegen.zig b/src/codegen.zig
index 7d5308b..e9eec9b 100644
--- a/src/codegen.zig
+++ b/src/codegen.zig
@@ -232,12 +232,23 @@ pub const CodeGen = struct {
 
         const res = llvm.LLVMBuildCall2(self.builder, try self.get_llvm_type(function.node_type), function.value, @ptrCast(arguments.items), @intCast(arguments.items.len), "") orelse return CodeGenError.CompilationError;
 
-        const function_return_type = switch (function.node.*) {
-            .FUNCTION_DEFINITION => |x| x.return_type,
-            .PRIMARY_EXPRESSION => |x| x.IDENTIFIER.type.?,
-            .TYPE => |x| x.FUNCTION_TYPE.return_type,
+        var function_return_type: *parser.Node = undefined;
+        switch (function.node.*) {
+            .FUNCTION_DEFINITION => |x| {
+                function_return_type = x.return_type;
+            },
+            .PRIMARY_EXPRESSION => |x| {
+                const f = self.environment.get_variable(x.IDENTIFIER.name).?.node_type;
+                std.debug.assert(f.TYPE == .FUNCTION_TYPE);
+                function_return_type = f.TYPE.FUNCTION_TYPE.return_type;
+            },
+            .TYPE => |x| {
+                function_return_type = x.FUNCTION_TYPE.return_type;
+            },
             else => unreachable,
-        };
+        }
+
+        std.debug.print("FN: {s} -> ret: {any}\n", .{ function_call_statement.expression.PRIMARY_EXPRESSION.IDENTIFIER.name, function_return_type });
 
         return self.create_variable(.{
             .value = res,
@@ -260,7 +271,7 @@ pub const CodeGen = struct {
 
         const val = try self.generate_expression_value(expression.?, null);
 
-        std.debug.print("3TYP {any}: {any} vs {any}\n", .{ expression.?, self.current_function_return_type.?, val.node_type });
+        std.debug.print("3TYP : {any} vs {any}\n", .{ self.current_function_return_type.?, val.node_type });
         std.debug.assert(self.compare_types(self.current_function_return_type.?, val.node_type, false));
 
         _ = llvm.LLVMBuildRet(self.builder, val.value);