about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
authorBaitinq <[email protected]>2025-05-04 23:14:05 +0200
committerBaitinq <[email protected]>2025-05-04 23:14:05 +0200
commit34772fc5ac9719f8ad4ce953d93f967aa645ad0d (patch)
tree862616681dd2c4fa3dd36769b57ba6ebb181c644 /src
parentCodegen: Fix regression with additions (diff)
downloadinterpreter-34772fc5ac9719f8ad4ce953d93f967aa645ad0d.tar.gz
interpreter-34772fc5ac9719f8ad4ce953d93f967aa645ad0d.tar.bz2
interpreter-34772fc5ac9719f8ad4ce953d93f967aa645ad0d.zip
Codegen: Simplify by not tracking llvm type
Diffstat (limited to 'src')
-rw-r--r--src/codegen.zig121
1 files changed, 69 insertions, 52 deletions
diff --git a/src/codegen.zig b/src/codegen.zig
index 3b6ab21..3969df5 100644
--- a/src/codegen.zig
+++ b/src/codegen.zig
@@ -134,10 +134,9 @@ pub const CodeGen = struct {
                 const alloca = llvm.LLVMBuildAlloca(self.builder, llvm.LLVMInt64Type(), try std.fmt.allocPrintZ(self.arena, "{s}", .{identifier.name})); //TODO: Correct type
                 try self.environment.add_variable(identifier.name, try self.create_variable(.{
                     .value = alloca,
-                    .type = llvm.LLVMVoidType(), // This gets set to the correct type during the expression type resolution. ALTERNATIVE: Pass the alloca
                     .stack_level = null,
                     .node = statement,
-                    .node_type = null,
+                    .node_type = null, // This gets set to the correct type during the expression type resolution. ALTERNATIVE: Pass the alloca
                 }));
             }
 
@@ -145,11 +144,10 @@ pub const CodeGen = struct {
             if (assignment_statement.is_dereference) {
                 const ptr = self.environment.get_variable(identifier.name) orelse unreachable;
                 undereferenced_variable = ptr;
-                const x = llvm.LLVMBuildLoad2(self.builder, ptr.type, ptr.value, "") orelse return CodeGenError.CompilationError;
+                const x = llvm.LLVMBuildLoad2(self.builder, try self.get_llvm_type(ptr.node_type.?), ptr.value, "") orelse return CodeGenError.CompilationError;
                 std.debug.assert(ptr.node_type.?.TYPE == .POINTER_TYPE);
                 try self.environment.add_variable(identifier.name, try self.create_variable(.{
                     .value = x,
-                    .type = ptr.type,
                     .stack_level = null,
                     .node = statement,
                     .node_type = ptr.node_type.?.TYPE.POINTER_TYPE.type,
@@ -184,7 +182,7 @@ pub const CodeGen = struct {
                 std.debug.assert(primary_expression == .IDENTIFIER);
                 function = self.environment.get_variable(primary_expression.IDENTIFIER.name) orelse return CodeGenError.CompilationError;
                 if (llvm.LLVMGetValueKind(function.value) != llvm.LLVMFunctionValueKind) {
-                    function.value = llvm.LLVMBuildLoad2(self.builder, llvm.LLVMPointerType(function.type, 0), function.value, "");
+                    function.value = llvm.LLVMBuildLoad2(self.builder, llvm.LLVMPointerType(try self.get_llvm_type(function.node_type.?), 0), function.value, "");
                     node = function.node.?;
                 }
             },
@@ -201,7 +199,7 @@ pub const CodeGen = struct {
             try arguments.append(arg.value);
         }
 
-        const res = llvm.LLVMBuildCall2(self.builder, function.type, function.value, @ptrCast(arguments.items), @intCast(arguments.items.len), "") orelse return CodeGenError.CompilationError;
+        const res = llvm.LLVMBuildCall2(self.builder, try self.get_llvm_type(function.node_type.?), function.value, @ptrCast(arguments.items), @intCast(arguments.items.len), "") orelse return CodeGenError.CompilationError;
 
         const function_return_type = switch (function.node.?.*) {
             .FUNCTION_DEFINITION => |x| x.return_type,
@@ -210,10 +208,7 @@ pub const CodeGen = struct {
             else => unreachable,
         };
 
-        const typ = try self.get_llvm_type(function_return_type);
-
         return self.create_variable(.{
-            .type = typ,
             .value = res,
             .stack_level = null,
             .node = node,
@@ -290,20 +285,22 @@ pub const CodeGen = struct {
                 // Functions should be declared "globally"
                 const builder_pos = llvm.LLVMGetInsertBlock(self.builder);
 
-                var paramtypes = std.ArrayList(llvm.LLVMTypeRef).init(self.arena);
+                var llvm_param_types = std.ArrayList(llvm.LLVMTypeRef).init(self.arena);
+                var param_types = std.ArrayList(*parser.Node).init(self.arena);
                 for (function_definition.parameters) |param| {
                     std.debug.assert(param.PRIMARY_EXPRESSION == .IDENTIFIER);
                     var param_type = try self.get_llvm_type(param.PRIMARY_EXPRESSION.IDENTIFIER.type.?);
                     if (param.PRIMARY_EXPRESSION.IDENTIFIER.type.?.TYPE == .FUNCTION_TYPE) {
                         param_type = llvm.LLVMPointerType(param_type.?, 0);
                     }
-                    try paramtypes.append(param_type);
+                    try llvm_param_types.append(param_type);
+                    try param_types.append(param.PRIMARY_EXPRESSION.IDENTIFIER.type.?);
                 }
                 var return_type = try self.get_llvm_type(function_definition.return_type);
                 if (function_definition.return_type.TYPE == .FUNCTION_TYPE) {
                     return_type = llvm.LLVMPointerType(return_type, 0);
                 }
-                const function_type = llvm.LLVMFunctionType(return_type, paramtypes.items.ptr, @intCast(paramtypes.items.len), 0) orelse return CodeGenError.CompilationError;
+                const function_type = llvm.LLVMFunctionType(return_type, llvm_param_types.items.ptr, @intCast(llvm_param_types.items.len), 0) orelse return CodeGenError.CompilationError;
                 const function = llvm.LLVMAddFunction(self.llvm_module, try std.fmt.allocPrintZ(self.arena, "{s}", .{name orelse "unnamed_func"}), function_type) orelse return CodeGenError.CompilationError;
                 const function_entry = llvm.LLVMAppendBasicBlock(function, "entrypoint") orelse return CodeGenError.CompilationError;
                 llvm.LLVMPositionBuilderAtEnd(self.builder, function_entry);
@@ -313,15 +310,23 @@ pub const CodeGen = struct {
 
                 var ptr: ?*Variable = null;
 
+                const node_type = try self.create_node(.{
+                    .TYPE = .{
+                        .FUNCTION_TYPE = .{
+                            .parameters = param_types.items,
+                            .return_type = function_definition.return_type,
+                        },
+                    },
+                });
+
                 // Needed for recursive functions
                 if (name != null) {
                     ptr = self.environment.get_variable(name.?);
                     try self.environment.add_variable(name.?, try self.create_variable(.{
                         .value = function,
-                        .type = function_type,
                         .stack_level = null,
                         .node = expression,
-                        .node_type = null,
+                        .node_type = node_type,
                     }));
                 }
 
@@ -345,7 +350,6 @@ pub const CodeGen = struct {
 
                     try self.environment.add_variable(param_node.PRIMARY_EXPRESSION.IDENTIFIER.name, try self.create_variable(.{
                         .value = alloca,
-                        .type = param_type,
                         .stack_level = null,
                         .node = param_node,
                         .node_type = param_node.PRIMARY_EXPRESSION.IDENTIFIER.type,
@@ -363,15 +367,14 @@ pub const CodeGen = struct {
                 if (name == null or self.environment.scope_stack.items.len == 2) {
                     return try self.create_variable(.{
                         .value = function,
-                        .type = function_type,
                         .stack_level = null,
                         .node = expression,
-                        .node_type = null,
+                        .node_type = node_type,
                     });
                 }
 
                 _ = llvm.LLVMBuildStore(self.builder, function, ptr.?.value) orelse return CodeGenError.CompilationError;
-                ptr.?.type = function_type;
+                ptr.?.node_type = node_type;
                 ptr.?.node = expression;
 
                 return ptr.?;
@@ -381,7 +384,6 @@ pub const CodeGen = struct {
                     const ptr = self.environment.get_variable(name.?) orelse unreachable;
                     const result = try self.generate_function_call_statement(@ptrCast(fn_call));
                     _ = llvm.LLVMBuildStore(self.builder, result.value, ptr.value) orelse return CodeGenError.CompilationError;
-                    ptr.type = result.type;
                     ptr.node = result.node;
                     ptr.node_type = result.node_type;
                     return ptr;
@@ -391,7 +393,7 @@ pub const CodeGen = struct {
             },
             .PRIMARY_EXPRESSION => |primary_expression| switch (primary_expression) {
                 .NUMBER => |n| {
-                    return try self.generate_literal(llvm.LLVMConstInt(llvm.LLVMInt64Type(), @intCast(n.value), 0), llvm.LLVMInt64Type(), name, expression, try self.create_node(.{
+                    return try self.generate_literal(llvm.LLVMConstInt(llvm.LLVMInt64Type(), @intCast(n.value), 0), name, expression, try self.create_node(.{
                         .TYPE = .{
                             .SIMPLE_TYPE = .{
                                 .name = "i64",
@@ -405,16 +407,16 @@ pub const CodeGen = struct {
                         true => 1,
                     };
 
-                    return try self.generate_literal(llvm.LLVMConstInt(llvm.LLVMInt1Type(), @intCast(int_value), 0), llvm.LLVMInt1Type(), name, expression, try self.create_node(.{
+                    return try self.generate_literal(llvm.LLVMConstInt(llvm.LLVMInt1Type(), @intCast(int_value), 0), name, expression, try self.create_node(.{
                         .TYPE = .{
                             .SIMPLE_TYPE = .{
-                                .name = "i1",
+                                .name = "bool",
                             },
                         },
                     }));
                 },
                 .CHAR => |c| {
-                    return try self.generate_literal(llvm.LLVMConstInt(llvm.LLVMInt8Type(), @intCast(c.value), 0), llvm.LLVMInt8Type(), name, expression, try self.create_node(.{
+                    return try self.generate_literal(llvm.LLVMConstInt(llvm.LLVMInt8Type(), @intCast(c.value), 0), name, expression, try self.create_node(.{
                         .TYPE = .{
                             .SIMPLE_TYPE = .{
                                 .name = "i8",
@@ -427,7 +429,6 @@ pub const CodeGen = struct {
                     return self.create_variable(
                         .{
                             .value = x,
-                            .type = llvm.LLVMPointerType(llvm.LLVMInt8Type(), 0),
                             .stack_level = null,
                             .node = expression,
                             .node_type = try self.create_node(.{
@@ -446,8 +447,8 @@ pub const CodeGen = struct {
                 },
                 .IDENTIFIER => |i| {
                     const variable = self.environment.get_variable(i.name).?;
-                    var param_type = variable.type;
-                    if (llvm.LLVMGetTypeKind(param_type.?) == llvm.LLVMFunctionTypeKind) {
+                    var param_type = try self.get_llvm_type(variable.node_type.?);
+                    if (variable.node_type.?.TYPE == .FUNCTION_TYPE) {
                         param_type = llvm.LLVMPointerType(param_type.?, 0);
                     }
 
@@ -459,7 +460,7 @@ pub const CodeGen = struct {
                         loaded = llvm.LLVMBuildLoad2(self.builder, param_type, variable.value, "");
                     }
 
-                    return self.generate_literal(loaded, variable.type, name, expression, variable.node_type);
+                    return self.generate_literal(loaded, name, expression, variable.node_type.?);
                 },
             },
             .ADDITIVE_EXPRESSION => |exp| {
@@ -472,7 +473,7 @@ pub const CodeGen = struct {
                 } } });
 
                 if (exp.addition) {
-                    if (llvm.LLVMGetTypeKind(lhs_value.type.?) == llvm.LLVMPointerTypeKind) {
+                    if (lhs_value.node_type.?.TYPE == .POINTER_TYPE) {
                         std.debug.print("DEBUG: {any}\n", .{expression});
                         result = llvm.LLVMBuildGEP2(self.builder, try self.get_llvm_type(lhs_value.node_type.?.TYPE.POINTER_TYPE.type), lhs_value.value, @constCast(&[_]llvm.LLVMValueRef{rhs_value.value}), 1, "");
                         node_type = lhs_value.node_type.?;
@@ -483,7 +484,7 @@ pub const CodeGen = struct {
                     result = llvm.LLVMBuildSub(self.builder, lhs_value.value, rhs_value.value, "") orelse return CodeGenError.CompilationError;
                 }
 
-                return self.generate_literal(result, llvm.LLVMInt64Type(), name, expression, node_type);
+                return self.generate_literal(result, name, expression, node_type);
             },
             .MULTIPLICATIVE_EXPRESSION => |exp| {
                 const lhs_value = try self.generate_expression_value(exp.lhs, null);
@@ -502,34 +503,44 @@ pub const CodeGen = struct {
                     },
                 }
 
-                return self.generate_literal(result, llvm.LLVMInt64Type(), name, expression, lhs_value.node_type);
+                return self.generate_literal(result, name, expression, lhs_value.node_type.?);
             },
             .UNARY_EXPRESSION => |exp| {
                 const k = try self.generate_expression_value(exp.expression, null);
 
                 var r: llvm.LLVMValueRef = undefined;
-                var t: llvm.LLVMTypeRef = undefined;
-                var uwu: *parser.Node = k.node_type.?;
+                var typ: *parser.Node = k.node_type.?;
                 switch (exp.typ) {
                     .NOT => {
-                        std.debug.assert(k.type == llvm.LLVMInt1Type());
+                        std.debug.assert(std.mem.eql(u8, k.node_type.?.TYPE.SIMPLE_TYPE.name, "bool")); //TODO
                         r = llvm.LLVMBuildICmp(self.builder, llvm.LLVMIntEQ, k.value, llvm.LLVMConstInt(llvm.LLVMInt1Type(), 0, 0), "");
-                        t = llvm.LLVMInt1Type();
+                        typ = try self.create_node(.{
+                            .TYPE = .{
+                                .SIMPLE_TYPE = .{
+                                    .name = "bool",
+                                },
+                            },
+                        });
                     },
                     .MINUS => {
                         r = llvm.LLVMBuildNeg(self.builder, k.value, "");
-                        t = llvm.LLVMInt64Type();
+                        typ = try self.create_node(.{
+                            .TYPE = .{
+                                .SIMPLE_TYPE = .{
+                                    .name = "i64",
+                                },
+                            },
+                        });
                     },
                     .STAR => {
-                        r = llvm.LLVMBuildLoad2(self.builder, k.type, k.value, "");
-                        std.debug.print("TEST: {any}\n", .{k.node_type});
                         std.debug.assert(k.node_type.?.TYPE == .POINTER_TYPE);
-                        t = try self.get_llvm_type(k.node_type.?.TYPE.POINTER_TYPE.type);
-                        uwu = k.node_type.?.TYPE.POINTER_TYPE.type;
+                        typ = k.node_type.?.TYPE.POINTER_TYPE.type;
+                        r = llvm.LLVMBuildLoad2(self.builder, try self.get_llvm_type(typ), k.value, "");
+                        std.debug.print("TESTXXX: {any}\n", .{k.node_type.?.TYPE.POINTER_TYPE.type.TYPE});
                     },
                 }
 
-                return self.generate_literal(r, t, name, expression, uwu); //TODO: Why do we need the llvm type at all
+                return self.generate_literal(r, name, expression, typ);
             },
             .EQUALITY_EXPRESSION => |exp| {
                 const lhs_value = try self.generate_expression_value(exp.lhs, null);
@@ -542,7 +553,13 @@ pub const CodeGen = struct {
                 };
                 const cmp = llvm.LLVMBuildICmp(self.builder, op, lhs_value.value, rhs_value.value, "");
 
-                return self.generate_literal(cmp, llvm.LLVMInt1Type(), name, expression, lhs_value.node_type);
+                return self.generate_literal(cmp, name, expression, try self.create_node(.{
+                    .TYPE = .{
+                        .SIMPLE_TYPE = .{
+                            .name = "bool",
+                        },
+                    },
+                }));
             },
             .TYPE => |typ| {
                 std.debug.assert(typ == .FUNCTION_TYPE);
@@ -553,7 +570,6 @@ pub const CodeGen = struct {
                 if (self.environment.scope_stack.items.len == 1) {
                     return try self.create_variable(.{
                         .value = function,
-                        .type = function_type,
                         .stack_level = null,
                         .node = expression,
                         .node_type = expression,
@@ -562,7 +578,6 @@ pub const CodeGen = struct {
 
                 const ptr = self.environment.get_variable(name.?);
                 _ = llvm.LLVMBuildStore(self.builder, function, ptr.?.value) orelse return CodeGenError.CompilationError;
-                ptr.?.type = function_type;
                 ptr.?.node = expression;
                 ptr.?.node_type = expression;
 
@@ -572,12 +587,11 @@ pub const CodeGen = struct {
         };
     }
 
-    fn generate_literal(self: *CodeGen, literal_val: llvm.LLVMValueRef, literal_type: llvm.LLVMTypeRef, name: ?[]const u8, node: *parser.Node, node_type: ?*parser.Node) !*Variable {
+    fn generate_literal(self: *CodeGen, literal_val: llvm.LLVMValueRef, name: ?[]const u8, node: *parser.Node, node_type: *parser.Node) !*Variable {
         if (name != null) {
             if (self.environment.scope_stack.items.len == 1) {
                 const ptr = try self.create_variable(.{
-                    .value = llvm.LLVMAddGlobal(self.llvm_module, literal_type, try std.fmt.allocPrintZ(self.arena, "{s}", .{name.?})),
-                    .type = literal_type,
+                    .value = llvm.LLVMAddGlobal(self.llvm_module, try self.get_llvm_type(node_type), try std.fmt.allocPrintZ(self.arena, "{s}", .{name.?})),
                     .stack_level = null,
                     .node = node,
                     .node_type = node_type,
@@ -587,7 +601,6 @@ pub const CodeGen = struct {
             }
             const ptr = self.environment.get_variable(name.?) orelse unreachable;
             _ = llvm.LLVMBuildStore(self.builder, literal_val, ptr.value) orelse return CodeGenError.CompilationError;
-            ptr.type = literal_type;
             ptr.node = node;
             ptr.node_type = node_type;
             return ptr;
@@ -595,7 +608,6 @@ pub const CodeGen = struct {
 
         return try self.create_variable(.{
             .value = literal_val,
-            .type = literal_type,
             .stack_level = null,
             .node = node,
             .node_type = node_type,
@@ -614,9 +626,11 @@ pub const CodeGen = struct {
                 if (std.mem.eql(u8, t.name, "void")) return llvm.LLVMVoidType();
                 unreachable;
             },
-            // TODO: Properly handle this vv
             .FUNCTION_TYPE => |t| {
-                const return_type = try self.get_llvm_type(t.return_type);
+                var return_type = try self.get_llvm_type(t.return_type);
+                if (t.return_type.TYPE == .FUNCTION_TYPE) {
+                    return_type = llvm.LLVMPointerType(return_type, 0);
+                }
                 var paramtypes = std.ArrayList(llvm.LLVMTypeRef).init(self.arena);
                 var is_varargs: i8 = 0;
                 for (t.parameters) |param| {
@@ -624,7 +638,11 @@ pub const CodeGen = struct {
                         is_varargs = 1;
                         continue;
                     }
-                    try paramtypes.append(try self.get_llvm_type(param));
+                    var typ = try self.get_llvm_type(param);
+                    if (param.TYPE == .FUNCTION_TYPE) {
+                        typ = llvm.LLVMPointerType(typ, 0);
+                    }
+                    try paramtypes.append(typ);
                 }
                 const function_type = llvm.LLVMFunctionType(return_type, paramtypes.items.ptr, @intCast(paramtypes.items.len), is_varargs) orelse unreachable;
                 return function_type;
@@ -663,7 +681,6 @@ pub const CodeGen = struct {
 };
 
 const Variable = struct {
-    type: llvm.LLVMTypeRef,
     value: llvm.LLVMValueRef,
     node: ?*parser.Node,
     node_type: ?*parser.Node,