about summary refs log tree commit diff
diff options
context:
space:
mode:
authorBaitinq <[email protected]>2025-02-23 21:25:39 +0100
committerBaitinq <[email protected]>2025-02-23 21:25:39 +0100
commit741f9cc124b43b5e24b30110a584c82bc8c211e5 (patch)
tree93080a0138820ca201c7e7275dcfd3aa8c6e9a3d
parentMisc: Fix llvm error message cleanup (diff)
downloadinterpreter-741f9cc124b43b5e24b30110a584c82bc8c211e5.tar.gz
interpreter-741f9cc124b43b5e24b30110a584c82bc8c211e5.tar.bz2
interpreter-741f9cc124b43b5e24b30110a584c82bc8c211e5.zip
Codegen: Support functions as function params
-rw-r--r--examples/11.src13
-rw-r--r--src/codegen.zig117
-rw-r--r--src/parser.zig2
3 files changed, 97 insertions, 35 deletions
diff --git a/examples/11.src b/examples/11.src
index adaabe4..5289f9a 100644
--- a/examples/11.src
+++ b/examples/11.src
@@ -1,12 +1,13 @@
 let main = () => i64 {
-	let x = (a: i64) => i64 {
-		print(a);
-		return 1;
+	let y = (f: (i64) => i64, x: i64) => i64 {
+		return f(x);
 	};
 
-	let y = (f: (i64) => i64) => i64 {
-		return f(2);
+	let id = (a: i64) => i64 {
+		print(a);
+		print(12);
+		return a;
 	};
 
-	return y(x);
+	return y(id, 2);
 };
diff --git a/src/codegen.zig b/src/codegen.zig
index 779f6d6..ec19a7e 100644
--- a/src/codegen.zig
+++ b/src/codegen.zig
@@ -47,6 +47,7 @@ pub const CodeGen = struct {
         try self.environment.add_variable("printf", try self.create_variable(.{
             .value = printf_function,
             .type = printf_function_type,
+            .stack_level = null,
         }));
 
         try self.create_print_function();
@@ -129,11 +130,13 @@ pub const CodeGen = struct {
 
         const assignment_statement = statement.ASSIGNMENT_STATEMENT;
 
-        if (assignment_statement.is_declaration and self.environment.scope_stack.items.len > 1 and assignment_statement.expression.* != .FUNCTION_DEFINITION) {
+        if (assignment_statement.is_declaration and self.environment.scope_stack.items.len > 1) {
+            // TODO: vv Int64Type is a problem
             const alloca = core.LLVMBuildAlloca(self.builder, core.LLVMInt64Type(), try std.fmt.allocPrintZ(self.arena, "{s}", .{assignment_statement.name})); //TODO: Correct type
             try self.environment.add_variable(assignment_statement.name, try self.create_variable(.{
                 .value = alloca,
                 .type = core.LLVMVoidType(), // This gets set to the correct type during the expression type resolution. ALTERNATIVE: Pass the alloca
+                .stack_level = null,
             }));
         }
 
@@ -151,6 +154,11 @@ pub const CodeGen = struct {
             .PRIMARY_EXPRESSION => |primary_expression| {
                 std.debug.assert(primary_expression == .IDENTIFIER);
                 function = self.environment.get_variable(primary_expression.IDENTIFIER.name) orelse return CodeGenError.CompilationError;
+                std.debug.print("STACK LVEL: {any} {s}\n", .{ function.stack_level.?, primary_expression.IDENTIFIER.name });
+                if (function.stack_level.? > 0) {
+                    std.debug.print("NOT GLOBAL FN! {s} {any}\n", .{ primary_expression.IDENTIFIER.name, function.stack_level });
+                    function.value = core.LLVMBuildLoad2(self.builder, core.LLVMPointerType(function.type, 0), function.value, "");
+                }
             },
             .FUNCTION_DEFINITION => |*function_definition| {
                 function = try self.generate_expression_value(@ptrCast(function_definition), null);
@@ -235,17 +243,18 @@ pub const CodeGen = struct {
         errdefer std.debug.print("Error generating statement value\n", .{});
         return switch (expression.*) {
             .FUNCTION_DEFINITION => |function_definition| {
-                try self.environment.create_scope();
-                defer self.environment.drop_scope();
 
                 // Functions should be declared "globally"
                 const builder_pos = core.LLVMGetInsertBlock(self.builder);
-                defer core.LLVMPositionBuilderAtEnd(self.builder, builder_pos);
 
                 var paramtypes = std.ArrayList(types.LLVMTypeRef).init(self.arena);
                 for (function_definition.parameters) |param| {
                     std.debug.assert(param.PRIMARY_EXPRESSION == .IDENTIFIER);
-                    try paramtypes.append(try self.get_llvm_type(param.PRIMARY_EXPRESSION.IDENTIFIER.type.?));
+                    var param_type = try self.get_llvm_type(param.PRIMARY_EXPRESSION.IDENTIFIER.type.?);
+                    if (param.PRIMARY_EXPRESSION.IDENTIFIER.type.?.TYPE == .FUNCTION_TYPE) {
+                        param_type = core.LLVMPointerType(param_type.?, 0);
+                    }
+                    try paramtypes.append(param_type);
                 }
                 const return_type = try self.get_llvm_type(function_definition.return_type);
                 const function_type = core.LLVMFunctionType(return_type, paramtypes.items.ptr, @intCast(paramtypes.items.len), 0) orelse return CodeGenError.CompilationError;
@@ -255,12 +264,24 @@ pub const CodeGen = struct {
 
                 // Needed for recursive functions
                 if (name != null) {
-                    try self.environment.add_variable(name.?, try self.create_variable(.{
-                        .value = function,
-                        .type = function_type,
-                    }));
+                    const ptr = self.environment.get_variable(name.?);
+                    // Global fn
+                    if (ptr == null) {
+                        try self.environment.add_variable(name.?, try self.create_variable(.{
+                            .value = function,
+                            .type = function_type,
+                            .stack_level = null,
+                        }));
+                    } else {
+                        _ = core.LLVMBuildStore(self.builder, function, ptr.?.value) orelse return CodeGenError.CompilationError;
+                        ptr.?.type = function_type;
+                        try self.environment.add_variable(name.?, ptr.?);
+                    }
                 }
 
+                try self.environment.create_scope();
+                defer self.environment.drop_scope();
+
                 const params = try self.arena.alloc(types.LLVMValueRef, function_definition.parameters.len);
                 core.LLVMGetParams(function, params.ptr);
 
@@ -271,13 +292,18 @@ pub const CodeGen = struct {
                     std.debug.assert(param_node.* == .PRIMARY_EXPRESSION);
 
                     const param_type = try self.get_llvm_type(param_node.PRIMARY_EXPRESSION.IDENTIFIER.type.?);
+                    var alloca_param_type = param_type;
+                    if (param_node.PRIMARY_EXPRESSION.IDENTIFIER.type.?.TYPE == .FUNCTION_TYPE) {
+                        alloca_param_type = core.LLVMPointerType(alloca_param_type.?, 0);
+                    }
                     // We need to alloca params because we assume all identifiers are alloca TODO:: Is this correct
-                    const alloca = core.LLVMBuildAlloca(self.builder, param_type, try std.fmt.allocPrintZ(self.arena, "{s}", .{param_node.PRIMARY_EXPRESSION.IDENTIFIER.name}));
+                    const alloca = core.LLVMBuildAlloca(self.builder, alloca_param_type, try std.fmt.allocPrintZ(self.arena, "{s}", .{param_node.PRIMARY_EXPRESSION.IDENTIFIER.name}));
                     _ = core.LLVMBuildStore(self.builder, p, alloca);
 
                     try self.environment.add_variable(param_node.PRIMARY_EXPRESSION.IDENTIFIER.name, try self.create_variable(.{
                         .value = alloca,
                         .type = param_type,
+                        .stack_level = null,
                     }));
                 }
 
@@ -285,10 +311,19 @@ pub const CodeGen = struct {
                     try self.generate_statement(stmt);
                 }
 
-                return try self.create_variable(.{
-                    .value = function,
-                    .type = function_type,
-                });
+                // Global functions
+                if (self.environment.scope_stack.items.len == 2) {
+                    return try self.create_variable(.{
+                        .value = function,
+                        .type = function_type,
+                        .stack_level = null,
+                    });
+                }
+                core.LLVMPositionBuilderAtEnd(self.builder, builder_pos);
+                const ptr = self.environment.get_variable(name.?) orelse unreachable;
+                _ = core.LLVMBuildStore(self.builder, function, ptr.value) orelse return CodeGenError.CompilationError;
+                ptr.type = function_type;
+                return ptr;
             },
             .FUNCTION_CALL_STATEMENT => |*fn_call| {
                 if (name != null) {
@@ -302,12 +337,13 @@ pub const CodeGen = struct {
                     return try self.create_variable(.{
                         .value = r,
                         .type = core.LLVMInt64Type(),
+                        .stack_level = null,
                     });
                 }
             },
             .PRIMARY_EXPRESSION => |primary_expression| switch (primary_expression) {
                 .NUMBER => |n| {
-                    return try self.generate_literal(n.value, core.LLVMInt64Type(), name);
+                    return try self.generate_literal(core.LLVMConstInt(core.LLVMInt64Type(), @intCast(n.value), 0), core.LLVMInt64Type(), name);
                 },
                 .BOOLEAN => |b| {
                     const int_value: i64 = switch (b.value) {
@@ -315,11 +351,15 @@ pub const CodeGen = struct {
                         true => 1,
                     };
 
-                    return try self.generate_literal(int_value, core.LLVMInt1Type(), name);
+                    return try self.generate_literal(core.LLVMConstInt(core.LLVMInt1Type(), @intCast(int_value), 0), core.LLVMInt1Type(), name);
                 },
                 .IDENTIFIER => |i| {
                     const variable = self.environment.get_variable(i.name).?;
-                    const loaded = core.LLVMBuildLoad2(self.builder, variable.type, variable.value, "");
+                    var param_type = variable.type;
+                    if (core.LLVMGetTypeKind(param_type.?) == types.LLVMTypeKind.LLVMFunctionTypeKind) {
+                        param_type = core.LLVMPointerType(param_type.?, 0);
+                    }
+                    const loaded = core.LLVMBuildLoad2(self.builder, param_type, variable.value, "");
 
                     if (name != null) {
                         const ptr = self.environment.get_variable(name.?).?;
@@ -331,6 +371,7 @@ pub const CodeGen = struct {
                     return try self.create_variable(.{
                         .value = loaded,
                         .type = variable.type,
+                        .stack_level = null,
                     });
                 },
             },
@@ -352,7 +393,11 @@ pub const CodeGen = struct {
 
                     return ptr;
                 } else {
-                    return try self.create_variable(.{ .value = result, .type = core.LLVMInt64Type() });
+                    return try self.create_variable(.{
+                        .value = result,
+                        .type = core.LLVMInt64Type(),
+                        .stack_level = null,
+                    });
                 }
             },
             .MULTIPLICATIVE_EXPRESSION => |exp| {
@@ -373,7 +418,11 @@ pub const CodeGen = struct {
 
                     return ptr;
                 } else {
-                    return try self.create_variable(.{ .value = result, .type = core.LLVMInt64Type() });
+                    return try self.create_variable(.{
+                        .value = result,
+                        .type = core.LLVMInt64Type(),
+                        .stack_level = null,
+                    });
                 }
             },
             .UNARY_EXPRESSION => |exp| {
@@ -404,6 +453,7 @@ pub const CodeGen = struct {
                     return try self.create_variable(.{
                         .value = r,
                         .type = t,
+                        .stack_level = null,
                     });
                 }
             },
@@ -424,6 +474,7 @@ pub const CodeGen = struct {
                     return try self.create_variable(.{
                         .value = cmp,
                         .type = core.LLVMInt1Type(),
+                        .stack_level = null,
                     });
                 }
             },
@@ -431,30 +482,30 @@ pub const CodeGen = struct {
         };
     }
 
-    fn generate_literal(self: *CodeGen, literal_val: i64, literal_type: types.LLVMTypeRef, name: ?[]const u8) !*Variable {
+    fn generate_literal(self: *CodeGen, literal_val: types.LLVMValueRef, literal_type: types.LLVMTypeRef, name: ?[]const u8) !*Variable {
         var variable: types.LLVMValueRef = undefined;
         if (name != null) {
             if (self.environment.scope_stack.items.len == 1) {
                 const ptr = try self.create_variable(.{
                     .value = core.LLVMAddGlobal(self.llvm_module, literal_type, try std.fmt.allocPrintZ(self.arena, "{s}", .{name.?})),
                     .type = literal_type,
+                    .stack_level = null,
                 });
-                core.LLVMSetInitializer(ptr.value, core.LLVMConstInt(literal_type, @intCast(literal_val), 0));
+                core.LLVMSetInitializer(ptr.value, literal_val);
                 return ptr;
             }
             const ptr = self.environment.get_variable(name.?) orelse unreachable;
-            const val =
-                core.LLVMConstInt(literal_type, @intCast(literal_val), 0);
-            _ = core.LLVMBuildStore(self.builder, val, ptr.value) orelse return CodeGenError.CompilationError;
+            _ = core.LLVMBuildStore(self.builder, literal_val, ptr.value) orelse return CodeGenError.CompilationError;
             ptr.type = literal_type;
             return ptr;
         } else {
-            variable = core.LLVMConstInt(literal_type, @intCast(literal_val), 0);
+            variable = literal_val;
         }
 
         return try self.create_variable(.{
             .value = variable,
             .type = literal_type,
+            .stack_level = null,
         });
     }
 
@@ -475,7 +526,8 @@ pub const CodeGen = struct {
                 for (t.parameters) |param| {
                     try paramtypes.append(try self.get_llvm_type(param));
                 }
-                return core.LLVMFunctionType(return_type, paramtypes.items.ptr, @intCast(paramtypes.items.len), 0) orelse unreachable;
+                const function_type = core.LLVMFunctionType(return_type, paramtypes.items.ptr, @intCast(paramtypes.items.len), 0) orelse unreachable;
+                return function_type;
             },
         }
     }
@@ -502,6 +554,7 @@ pub const CodeGen = struct {
         try self.environment.add_variable("print", try self.create_variable(.{
             .value = print_function,
             .type = print_function_type,
+            .stack_level = null,
         }));
     }
 
@@ -530,6 +583,7 @@ pub const CodeGen = struct {
         try self.environment.add_variable("printb", try self.create_variable(.{
             .value = print_function,
             .type = print_function_type,
+            .stack_level = null,
         }));
     }
 
@@ -543,6 +597,7 @@ pub const CodeGen = struct {
 const Variable = struct {
     type: types.LLVMTypeRef,
     value: types.LLVMValueRef,
+    stack_level: ?usize,
 };
 
 const Scope = struct {
@@ -586,12 +641,18 @@ const Environment = struct {
 
     fn get_variable(self: *Environment, name: []const u8) ?*Variable {
         var i = self.scope_stack.items.len;
+        var variable: ?*Variable = null;
         while (i > 0) {
             i -= 1;
             const scope = self.scope_stack.items[i];
-            if (scope.variables.get(name)) |v| return v;
+            if (scope.variables.get(name)) |v| {
+                if (variable == null) {
+                    variable = v;
+                }
+                variable.?.stack_level = i;
+            }
         }
-        return null;
+        return variable;
     }
 
     fn contains_variable(self: *Environment, name: []const u8) bool {
diff --git a/src/parser.zig b/src/parser.zig
index 69f58be..f7a6806 100644
--- a/src/parser.zig
+++ b/src/parser.zig
@@ -423,7 +423,7 @@ pub const Parser = struct {
 
     // FunctionParameters ::= IDENTIFIER ":" Type ("," IDENTIFIER ":" Type)*
     fn parse_function_parameters(self: *Parser) ParserError![]*Node {
-        errdefer std.debug.print("Error parsing function parameters {any}\n", .{self.peek_token()});
+        errdefer if (!self.try_context) std.debug.print("Error parsing function parameters {any}\n", .{self.peek_token()});
 
         var node_list = std.ArrayList(*Node).init(self.arena);