about summary refs log tree commit diff
diff options
context:
space:
mode:
authorBaitinq <[email protected]>2025-05-09 23:48:28 +0200
committerBaitinq <[email protected]>2025-05-10 01:04:23 +0200
commit35e8c2eb3d7a00d5030a5f22ac6f1815df84e299 (patch)
tree27fa05e42be021abe6c398db8ae93aa971cb7c6d
parentExamples: Fix example 16 (diff)
downloadpry-lang-35e8c2eb3d7a00d5030a5f22ac6f1815df84e299.tar.gz
pry-lang-35e8c2eb3d7a00d5030a5f22ac6f1815df84e299.tar.bz2
pry-lang-35e8c2eb3d7a00d5030a5f22ac6f1815df84e299.zip
Codegen: Don't rely on llvm types
-rw-r--r--src/codegen.zig158
1 files changed, 73 insertions, 85 deletions
diff --git a/src/codegen.zig b/src/codegen.zig
index 3969df5..a6d8ed0 100644
--- a/src/codegen.zig
+++ b/src/codegen.zig
@@ -127,39 +127,50 @@ pub const CodeGen = struct {
 
         if (assignment_statement.lhs.* == .PRIMARY_EXPRESSION) {
             const identifier = assignment_statement.lhs.PRIMARY_EXPRESSION.IDENTIFIER;
+            const variable = try self.generate_expression_value(assignment_statement.rhs, identifier.name);
 
-            if (assignment_statement.is_declaration and self.environment.scope_stack.items.len > 1) {
-                std.debug.assert(assignment_statement.is_dereference == false);
-                // TODO: vv Int64Type is a problem
-                const alloca = llvm.LLVMBuildAlloca(self.builder, llvm.LLVMInt64Type(), try std.fmt.allocPrintZ(self.arena, "{s}", .{identifier.name})); //TODO: Correct type
+            if (self.environment.scope_stack.items.len == 1) {
                 try self.environment.add_variable(identifier.name, try self.create_variable(.{
-                    .value = alloca,
+                    .value = variable.value,
+                    .node = variable.node,
+                    .node_type = variable.node_type,
                     .stack_level = null,
-                    .node = statement,
-                    .node_type = null, // This gets set to the correct type during the expression type resolution. ALTERNATIVE: Pass the alloca
                 }));
+                return;
+            }
+
+            var ptr: llvm.LLVMValueRef = undefined;
+            if (assignment_statement.is_declaration) {
+                var x = try self.get_llvm_type(variable.node_type);
+                if (variable.node_type.TYPE == .FUNCTION_TYPE) {
+                    x = llvm.LLVMPointerType(x, 0);
+                }
+                ptr = llvm.LLVMBuildAlloca(self.builder, x, try std.fmt.allocPrintZ(self.arena, "{s}", .{identifier.name}));
+            } else {
+                ptr = self.environment.get_variable(identifier.name).?.value;
             }
 
-            var undereferenced_variable: ?*Variable = null;
             if (assignment_statement.is_dereference) {
-                const ptr = self.environment.get_variable(identifier.name) orelse unreachable;
-                undereferenced_variable = ptr;
-                const x = llvm.LLVMBuildLoad2(self.builder, try self.get_llvm_type(ptr.node_type.?), ptr.value, "") orelse return CodeGenError.CompilationError;
-                std.debug.assert(ptr.node_type.?.TYPE == .POINTER_TYPE);
-                try self.environment.add_variable(identifier.name, try self.create_variable(.{
-                    .value = x,
-                    .stack_level = null,
-                    .node = statement,
-                    .node_type = ptr.node_type.?.TYPE.POINTER_TYPE.type,
-                }));
+                ptr = llvm.LLVMBuildLoad2(self.builder, try self.get_llvm_type(variable.node_type), ptr, "");
             }
 
-            const variable = try self.generate_expression_value(assignment_statement.rhs, identifier.name);
+            _ = llvm.LLVMBuildStore(self.builder, variable.value, ptr);
 
-            if (!assignment_statement.is_dereference) {
-                try self.environment.add_variable(identifier.name, variable);
+            if (assignment_statement.is_dereference) {
+                ptr = self.environment.get_variable(identifier.name).?.value;
+            }
+
+            const new_variable = try self.create_variable(.{
+                .value = ptr,
+                .node = variable.node,
+                .node_type = variable.node_type,
+                .stack_level = null,
+            });
+            // Adding variable doesnt actually replace the variable of previous scope
+            if (assignment_statement.is_declaration) {
+                try self.environment.add_variable(identifier.name, new_variable);
             } else {
-                try self.environment.add_variable(identifier.name, undereferenced_variable.?);
+                try self.environment.set_variable(identifier.name, new_variable);
             }
         } else {
             const xd = assignment_statement.lhs.UNARY_EXPRESSION.expression;
@@ -182,8 +193,8 @@ pub const CodeGen = struct {
                 std.debug.assert(primary_expression == .IDENTIFIER);
                 function = self.environment.get_variable(primary_expression.IDENTIFIER.name) orelse return CodeGenError.CompilationError;
                 if (llvm.LLVMGetValueKind(function.value) != llvm.LLVMFunctionValueKind) {
-                    function.value = llvm.LLVMBuildLoad2(self.builder, llvm.LLVMPointerType(try self.get_llvm_type(function.node_type.?), 0), function.value, "");
-                    node = function.node.?;
+                    function.value = llvm.LLVMBuildLoad2(self.builder, llvm.LLVMPointerType(try self.get_llvm_type(function.node_type), 0), function.value, "");
+                    node = function.node;
                 }
             },
             .FUNCTION_DEFINITION => |*function_definition| {
@@ -199,9 +210,9 @@ pub const CodeGen = struct {
             try arguments.append(arg.value);
         }
 
-        const res = llvm.LLVMBuildCall2(self.builder, try self.get_llvm_type(function.node_type.?), function.value, @ptrCast(arguments.items), @intCast(arguments.items.len), "") orelse return CodeGenError.CompilationError;
+        const res = llvm.LLVMBuildCall2(self.builder, try self.get_llvm_type(function.node_type), function.value, @ptrCast(arguments.items), @intCast(arguments.items.len), "") orelse return CodeGenError.CompilationError;
 
-        const function_return_type = switch (function.node.?.*) {
+        const function_return_type = switch (function.node.*) {
             .FUNCTION_DEFINITION => |x| x.return_type,
             .PRIMARY_EXPRESSION => |x| x.IDENTIFIER.type.?,
             .TYPE => |x| x.FUNCTION_TYPE.return_type,
@@ -308,8 +319,6 @@ pub const CodeGen = struct {
                 try self.environment.create_scope();
                 defer self.environment.drop_scope();
 
-                var ptr: ?*Variable = null;
-
                 const node_type = try self.create_node(.{
                     .TYPE = .{
                         .FUNCTION_TYPE = .{
@@ -319,9 +328,8 @@ pub const CodeGen = struct {
                     },
                 });
 
-                // Needed for recursive functions
+                // // Needed for recursive functions
                 if (name != null) {
-                    ptr = self.environment.get_variable(name.?);
                     try self.environment.add_variable(name.?, try self.create_variable(.{
                         .value = function,
                         .stack_level = null,
@@ -344,7 +352,7 @@ pub const CodeGen = struct {
                     if (param_node.PRIMARY_EXPRESSION.IDENTIFIER.type.?.TYPE == .FUNCTION_TYPE) {
                         alloca_param_type = llvm.LLVMPointerType(alloca_param_type.?, 0);
                     }
-                    // We need to alloca params because we assume all identifiers are alloca TODO:: Is this correct
+                    // We need to alloca params because we assume all identifiers are alloca
                     const alloca = llvm.LLVMBuildAlloca(self.builder, alloca_param_type, try std.fmt.allocPrintZ(self.arena, "{s}", .{param_node.PRIMARY_EXPRESSION.IDENTIFIER.name}));
                     _ = llvm.LLVMBuildStore(self.builder, p, alloca);
 
@@ -352,7 +360,7 @@ pub const CodeGen = struct {
                         .value = alloca,
                         .stack_level = null,
                         .node = param_node,
-                        .node_type = param_node.PRIMARY_EXPRESSION.IDENTIFIER.type,
+                        .node_type = param_node.PRIMARY_EXPRESSION.IDENTIFIER.type.?,
                     }));
                 }
 
@@ -373,23 +381,15 @@ pub const CodeGen = struct {
                     });
                 }
 
-                _ = llvm.LLVMBuildStore(self.builder, function, ptr.?.value) orelse return CodeGenError.CompilationError;
-                ptr.?.node_type = node_type;
-                ptr.?.node = expression;
-
-                return ptr.?;
+                return try self.create_variable(.{
+                    .value = function,
+                    .stack_level = null,
+                    .node = expression,
+                    .node_type = node_type,
+                });
             },
             .FUNCTION_CALL_STATEMENT => |*fn_call| {
-                if (name != null) {
-                    const ptr = self.environment.get_variable(name.?) orelse unreachable;
-                    const result = try self.generate_function_call_statement(@ptrCast(fn_call));
-                    _ = llvm.LLVMBuildStore(self.builder, result.value, ptr.value) orelse return CodeGenError.CompilationError;
-                    ptr.node = result.node;
-                    ptr.node_type = result.node_type;
-                    return ptr;
-                } else {
-                    return try self.generate_function_call_statement(@ptrCast(fn_call));
-                }
+                return try self.generate_function_call_statement(@ptrCast(fn_call));
             },
             .PRIMARY_EXPRESSION => |primary_expression| switch (primary_expression) {
                 .NUMBER => |n| {
@@ -447,20 +447,13 @@ pub const CodeGen = struct {
                 },
                 .IDENTIFIER => |i| {
                     const variable = self.environment.get_variable(i.name).?;
-                    var param_type = try self.get_llvm_type(variable.node_type.?);
-                    if (variable.node_type.?.TYPE == .FUNCTION_TYPE) {
+                    var param_type = try self.get_llvm_type(variable.node_type);
+                    if (variable.node_type.TYPE == .FUNCTION_TYPE) {
                         param_type = llvm.LLVMPointerType(param_type.?, 0);
                     }
 
-                    var loaded: llvm.LLVMValueRef = undefined;
-
-                    if (variable.node.?.* == .PRIMARY_EXPRESSION and variable.node.?.PRIMARY_EXPRESSION == .STRING) {
-                        loaded = variable.value;
-                    } else {
-                        loaded = llvm.LLVMBuildLoad2(self.builder, param_type, variable.value, "");
-                    }
-
-                    return self.generate_literal(loaded, name, expression, variable.node_type.?);
+                    const loaded = llvm.LLVMBuildLoad2(self.builder, param_type, variable.value, "");
+                    return self.generate_literal(loaded, name, expression, variable.node_type);
                 },
             },
             .ADDITIVE_EXPRESSION => |exp| {
@@ -473,10 +466,9 @@ pub const CodeGen = struct {
                 } } });
 
                 if (exp.addition) {
-                    if (lhs_value.node_type.?.TYPE == .POINTER_TYPE) {
-                        std.debug.print("DEBUG: {any}\n", .{expression});
-                        result = llvm.LLVMBuildGEP2(self.builder, try self.get_llvm_type(lhs_value.node_type.?.TYPE.POINTER_TYPE.type), lhs_value.value, @constCast(&[_]llvm.LLVMValueRef{rhs_value.value}), 1, "");
-                        node_type = lhs_value.node_type.?;
+                    if (lhs_value.node_type.TYPE == .POINTER_TYPE) {
+                        result = llvm.LLVMBuildGEP2(self.builder, try self.get_llvm_type(lhs_value.node_type.TYPE.POINTER_TYPE.type), lhs_value.value, @constCast(&[_]llvm.LLVMValueRef{rhs_value.value}), 1, "");
+                        node_type = lhs_value.node_type;
                     } else {
                         result = llvm.LLVMBuildAdd(self.builder, lhs_value.value, rhs_value.value, "") orelse return CodeGenError.CompilationError;
                     }
@@ -503,16 +495,16 @@ pub const CodeGen = struct {
                     },
                 }
 
-                return self.generate_literal(result, name, expression, lhs_value.node_type.?);
+                return self.generate_literal(result, name, expression, lhs_value.node_type);
             },
             .UNARY_EXPRESSION => |exp| {
                 const k = try self.generate_expression_value(exp.expression, null);
 
                 var r: llvm.LLVMValueRef = undefined;
-                var typ: *parser.Node = k.node_type.?;
+                var typ: *parser.Node = k.node_type;
                 switch (exp.typ) {
                     .NOT => {
-                        std.debug.assert(std.mem.eql(u8, k.node_type.?.TYPE.SIMPLE_TYPE.name, "bool")); //TODO
+                        std.debug.assert(std.mem.eql(u8, k.node_type.TYPE.SIMPLE_TYPE.name, "bool"));
                         r = llvm.LLVMBuildICmp(self.builder, llvm.LLVMIntEQ, k.value, llvm.LLVMConstInt(llvm.LLVMInt1Type(), 0, 0), "");
                         typ = try self.create_node(.{
                             .TYPE = .{
@@ -533,10 +525,9 @@ pub const CodeGen = struct {
                         });
                     },
                     .STAR => {
-                        std.debug.assert(k.node_type.?.TYPE == .POINTER_TYPE);
-                        typ = k.node_type.?.TYPE.POINTER_TYPE.type;
+                        std.debug.assert(k.node_type.TYPE == .POINTER_TYPE);
+                        typ = k.node_type.TYPE.POINTER_TYPE.type;
                         r = llvm.LLVMBuildLoad2(self.builder, try self.get_llvm_type(typ), k.value, "");
-                        std.debug.print("TESTXXX: {any}\n", .{k.node_type.?.TYPE.POINTER_TYPE.type.TYPE});
                     },
                 }
 
@@ -588,21 +579,14 @@ pub const CodeGen = struct {
     }
 
     fn generate_literal(self: *CodeGen, literal_val: llvm.LLVMValueRef, name: ?[]const u8, node: *parser.Node, node_type: *parser.Node) !*Variable {
-        if (name != null) {
-            if (self.environment.scope_stack.items.len == 1) {
-                const ptr = try self.create_variable(.{
-                    .value = llvm.LLVMAddGlobal(self.llvm_module, try self.get_llvm_type(node_type), try std.fmt.allocPrintZ(self.arena, "{s}", .{name.?})),
-                    .stack_level = null,
-                    .node = node,
-                    .node_type = node_type,
-                });
-                llvm.LLVMSetInitializer(ptr.value, literal_val);
-                return ptr;
-            }
-            const ptr = self.environment.get_variable(name.?) orelse unreachable;
-            _ = llvm.LLVMBuildStore(self.builder, literal_val, ptr.value) orelse return CodeGenError.CompilationError;
-            ptr.node = node;
-            ptr.node_type = node_type;
+        if (name != null and self.environment.scope_stack.items.len == 1) {
+            const ptr = try self.create_variable(.{
+                .value = llvm.LLVMAddGlobal(self.llvm_module, try self.get_llvm_type(node_type), try std.fmt.allocPrintZ(self.arena, "{s}", .{name.?})),
+                .stack_level = null,
+                .node = node,
+                .node_type = node_type,
+            });
+            llvm.LLVMSetInitializer(ptr.value, literal_val);
             return ptr;
         }
 
@@ -682,8 +666,8 @@ pub const CodeGen = struct {
 
 const Variable = struct {
     value: llvm.LLVMValueRef,
-    node: ?*parser.Node,
-    node_type: ?*parser.Node,
+    node: *parser.Node,
+    node_type: *parser.Node,
     stack_level: ?usize,
 };
 
@@ -726,6 +710,10 @@ const Environment = struct {
         try self.scope_stack.getLast().variables.put(name, variable);
     }
 
+    fn set_variable(self: *Environment, name: []const u8, variable: *Variable) !void {
+        self.get_variable(name).?.* = variable.*;
+    }
+
     fn get_variable(self: *Environment, name: []const u8) ?*Variable {
         var i = self.scope_stack.items.len;
         var variable: ?*Variable = null;