about summary refs log tree commit diff
path: root/src/codegen.zig
diff options
context:
space:
mode:
authorBaitinq <[email protected]>2025-05-20 23:19:46 +0200
committerBaitinq <[email protected]>2025-05-20 23:19:46 +0200
commitb736233822765222c3963d03749b8f6200000dd3 (patch)
tree868202dd7161897cbc2b0f745791b57d3608e140 /src/codegen.zig
parentFeature: Add support for casting types (diff)
downloadpry-lang-b736233822765222c3963d03749b8f6200000dd3.tar.gz
pry-lang-b736233822765222c3963d03749b8f6200000dd3.tar.bz2
pry-lang-b736233822765222c3963d03749b8f6200000dd3.zip
Feature: Add more type checks
Diffstat (limited to 'src/codegen.zig')
-rw-r--r--src/codegen.zig38
1 files changed, 25 insertions, 13 deletions
diff --git a/src/codegen.zig b/src/codegen.zig
index d6415cf..c9e39ea 100644
--- a/src/codegen.zig
+++ b/src/codegen.zig
@@ -161,16 +161,15 @@ pub const CodeGen = struct {
             } else {
                 ptr = self.environment.get_variable(identifier.name).?.value;
                 typ = self.environment.get_variable(identifier.name).?.node_type;
+                // TODO: Do this in more places! (everywhere get_llvm_type?)  Also check types in return and cmp
+                const expected_type = typ;
+                std.debug.print("TYP {s}: {any} vs {any} -- {any}\n", .{ identifier.name, expected_type.TYPE, variable.node_type.TYPE, variable.node });
+                std.debug.assert(self.compare_types(expected_type, variable.node_type, assignment_statement.is_dereference));
             }
 
             if (assignment_statement.is_dereference) {
                 ptr = llvm.LLVMBuildLoad2(self.builder, try self.get_llvm_type(typ), ptr, "");
-            } else {
-                // TODO: we should still do this with dereferences, but differently
-                // TODO: Do this in more places! (everywhere get_llvm_type?)
-                std.debug.assert(self.compare_types(typ, variable.node_type));
             }
-
             _ = llvm.LLVMBuildStore(self.builder, variable.value, ptr);
 
             if (assignment_statement.is_dereference) {
@@ -187,7 +186,6 @@ pub const CodeGen = struct {
             if (assignment_statement.is_declaration) {
                 try self.environment.add_variable(identifier.name, new_variable);
             } else {
-                // TODO: Dont allow changing types of variables if its not declaration
                 try self.environment.set_variable(identifier.name, new_variable);
             }
         } else {
@@ -223,8 +221,11 @@ pub const CodeGen = struct {
 
         var arguments = std.ArrayList(llvm.LLVMValueRef).init(self.arena);
 
-        for (function_call_statement.arguments) |argument| {
+        for (0.., function_call_statement.arguments) |i, argument| {
             const arg = try self.generate_expression_value(argument, null);
+            const expected_type = function.node_type.TYPE.FUNCTION_TYPE.parameters[i];
+            std.debug.print("TYP {s}: {any} vs {any}\n", .{ function_call_statement.expression.PRIMARY_EXPRESSION.IDENTIFIER.name, expected_type.TYPE, arg.node_type.TYPE });
+            std.debug.assert(self.compare_types(expected_type, arg.node_type, false));
             try arguments.append(arg.value);
         }
 
@@ -458,6 +459,7 @@ pub const CodeGen = struct {
             },
             .PRIMARY_EXPRESSION => |primary_expression| switch (primary_expression) {
                 .NULL => {
+                    //TODO: This should likely be *void.
                     return try self.generate_literal(llvm.LLVMConstNull(llvm.LLVMPointerType(llvm.LLVMInt8Type(), 0)), name, expression, try self.create_node(.{
                         .TYPE = .{
                             .POINTER_TYPE = .{
@@ -729,15 +731,24 @@ pub const CodeGen = struct {
         }
     }
 
-    fn compare_types(self: *CodeGen, a: *parser.Node, b: *parser.Node) bool {
+    fn compare_types(self: *CodeGen, a: *parser.Node, b: *parser.Node, is_dereference: bool) bool {
         std.debug.assert(a.* == parser.Node.TYPE);
         std.debug.assert(b.* == parser.Node.TYPE);
 
-        const a_type = a.TYPE;
+        var a_type = a.TYPE;
         const b_type = b.TYPE;
 
+        if (a_type == .SIMPLE_TYPE and std.mem.eql(u8, "varargs", a_type.SIMPLE_TYPE.name)) {
+            return true;
+        }
+
+        if (is_dereference) {
+            a_type = a_type.POINTER_TYPE.type.TYPE;
+        }
+
         if (!std.mem.eql(u8, @tagName(a_type), @tagName(b_type))) {
-            std.debug.print("Tagname mismatch: {s} vs {s}\n", .{ @tagName(a_type), @tagName(b_type) });
+            std.debug.print("Tagname mismatch: {any} vs {any}\n", .{ a_type, b_type });
+            return false;
         }
 
         switch (a_type) {
@@ -752,7 +763,7 @@ pub const CodeGen = struct {
             .FUNCTION_TYPE => |a_func| {
                 const b_func = b_type.FUNCTION_TYPE;
 
-                if (!self.compare_types(a_func.return_type, b_func.return_type)) {
+                if (!self.compare_types(a_func.return_type, b_func.return_type, false)) {
                     std.debug.print("Function return type mismatch\n", .{});
                     return false;
                 }
@@ -763,7 +774,7 @@ pub const CodeGen = struct {
                 }
 
                 for (a_func.parameters, b_func.parameters) |a_param, b_param| {
-                    if (!self.compare_types(a_param, b_param)) {
+                    if (!self.compare_types(a_param, b_param, false)) {
                         std.debug.print("Parameter  type mismatch\n", .{});
                         return false;
                     }
@@ -773,7 +784,8 @@ pub const CodeGen = struct {
             },
             .POINTER_TYPE => |a_ptr| {
                 const b_ptr = b_type.POINTER_TYPE;
-                const res = self.compare_types(a_ptr.type, b_ptr.type);
+
+                const res = self.compare_types(a_ptr.type, b_ptr.type, false);
                 if (!res) {
                     std.debug.print("Pointer base type mismatch\n", .{});
                 }