about summary refs log tree commit diff
diff options
context:
space:
mode:
authorBaitinq <[email protected]>2025-01-30 20:20:56 +0100
committerBaitinq <[email protected]>2025-01-30 20:21:33 +0100
commit139abaf0a1a49a1e0478764a3e767e0ac69f56d2 (patch)
tree140af3ed05514ae1bca5319015a223c9d8f776f4
parentExamples: Add future types example (diff)
downloadinterpreter-139abaf0a1a49a1e0478764a3e767e0ac69f56d2.tar.gz
interpreter-139abaf0a1a49a1e0478764a3e767e0ac69f56d2.tar.bz2
interpreter-139abaf0a1a49a1e0478764a3e767e0ac69f56d2.zip
Codegen: Cleanup
-rw-r--r--examples/12.src4
-rw-r--r--src/codegen.zig116
-rw-r--r--src/main.zig4
3 files changed, 60 insertions, 64 deletions
diff --git a/examples/12.src b/examples/12.src
index 95f5993..3174197 100644
--- a/examples/12.src
+++ b/examples/12.src
@@ -1,8 +1,8 @@
-let print_int = (n: i32) => {
+let print_int = (n: i32) => i32 {
 	print(n);
 	return n;
 };
 
-let main = (argc: i32) => {
+let main = (argc: i32) => i32 {
 	return print_int(argc);
 };
diff --git a/src/codegen.zig b/src/codegen.zig
index b60ffa6..b8606a2 100644
--- a/src/codegen.zig
+++ b/src/codegen.zig
@@ -138,52 +138,15 @@ pub const CodeGen = struct {
 
         const assignment_statement = statement.ASSIGNMENT_STATEMENT;
 
-        //tmp
-        const variable_name = try std.fmt.allocPrintZ(self.arena, "{s}", .{assignment_statement.name});
-
-        switch (assignment_statement.expression.*) {
-            .FUNCTION_DEFINITION => {
-                const function_type = core.LLVMFunctionType(core.LLVMInt64Type(), &[_]types.LLVMTypeRef{}, 0, 0) orelse return CodeGenError.CompilationError;
-                const function = core.LLVMAddFunction(self.llvm_module, variable_name, function_type) orelse return CodeGenError.CompilationError;
-                const function_entry = core.LLVMAppendBasicBlock(function, "entrypoint") orelse return CodeGenError.CompilationError;
-                core.LLVMPositionBuilderAtEnd(self.builder, function_entry);
-
-                //tmp
-                std.debug.assert(assignment_statement.expression.* == parser.Node.FUNCTION_DEFINITION);
-                const function_defintion = assignment_statement.expression.FUNCTION_DEFINITION;
-
-                for (function_defintion.statements) |stmt| {
-                    try self.generate_statement(stmt);
-                }
-
-                try self.symbol_table.put(variable_name, try self.create_variable(.{
-                    .value = function,
-                    .type = function_type,
-                }));
-            },
-            .PRIMARY_EXPRESSION => |exp| {
-                switch (exp) {
-                    .NUMBER => {
-                        var variable: types.LLVMValueRef = undefined;
-                        if (self.symbol_table.get(variable_name)) |v| {
-                            variable = v.value;
-                        } else {
-                            variable = core.LLVMBuildAlloca(self.builder, core.LLVMInt64Type(), variable_name) orelse return CodeGenError.CompilationError;
-                        }
-                        _ = core.LLVMBuildStore(self.builder, core.LLVMConstInt(core.LLVMInt64Type(), @intCast(exp.NUMBER.value), 0), variable) orelse return CodeGenError.CompilationError;
-                        try self.symbol_table.put(variable_name, try self.create_variable(.{
-                            .value = variable,
-                            .type = core.LLVMInt64Type(),
-                        }));
-                    },
-                    else => unreachable,
-                }
-            },
-            else => unreachable,
+        if (!assignment_statement.is_declaration) {
+            std.debug.assert(self.symbol_table.contains(assignment_statement.name));
         }
+
+        const variable = try self.generate_expression_value(assignment_statement.expression);
+        try self.symbol_table.put(assignment_statement.name, variable);
     }
 
-    fn generate_function_call_statement(self: *CodeGen, statement: *parser.Node) !types.LLVMValueRef {
+    fn generate_function_call_statement(self: *CodeGen, statement: *parser.Node) CodeGenError!types.LLVMValueRef {
         std.debug.assert(statement.* == parser.Node.FUNCTION_CALL_STATEMENT);
         const function_call_statement = statement.FUNCTION_CALL_STATEMENT;
 
@@ -196,17 +159,13 @@ pub const CodeGen = struct {
         var arguments = std.ArrayList(types.LLVMValueRef).init(self.arena);
 
         for (function_call_statement.arguments) |argument| {
-            const num_argument: types.LLVMValueRef = switch (argument.PRIMARY_EXPRESSION) {
-                .NUMBER => |n| core.LLVMConstInt(core.LLVMInt64Type(), @intCast(n.value), 0),
-                .IDENTIFIER => |i| core.LLVMBuildLoad2(self.builder, core.LLVMInt64Type(), self.symbol_table.get(i.name).?.value, "").?,
-                else => unreachable,
-            };
-            try arguments.append(num_argument);
+            const arg = try self.generate_expression_value(argument);
+            try arguments.append(arg.value);
         }
 
         const xd = self.symbol_table.get(ident.name) orelse return CodeGenError.CompilationError;
 
-        return core.LLVMBuildCall2(self.builder, xd.type, xd.value, @ptrCast(arguments.items), @intCast(arguments.items.len), "function_call") orelse return CodeGenError.CompilationError;
+        return core.LLVMBuildCall2(self.builder, xd.type, xd.value, @ptrCast(arguments.items), @intCast(arguments.items.len), "") orelse return CodeGenError.CompilationError;
     }
 
     fn generate_return_statement(self: *CodeGen, statement: *parser.Node) !void {
@@ -214,29 +173,64 @@ pub const CodeGen = struct {
 
         const expression = statement.RETURN_STATEMENT.expression;
 
-        // TODO: Abstract this as we also need this for decls
-        const num_argument: types.LLVMValueRef = switch (expression.*) {
+        _ = core.LLVMBuildRet(self.builder, (try self.generate_expression_value(expression)).value);
+    }
+
+    fn generate_expression_value(self: *CodeGen, expression: *parser.Node) !*Variable {
+        return switch (expression.*) {
+            .FUNCTION_DEFINITION => |function_definition| {
+                const function_type = core.LLVMFunctionType(core.LLVMInt64Type(), &[_]types.LLVMTypeRef{}, 0, 0) orelse return CodeGenError.CompilationError;
+                const function = core.LLVMAddFunction(self.llvm_module, "", function_type) orelse return CodeGenError.CompilationError;
+                const function_entry = core.LLVMAppendBasicBlock(function, "entrypoint") orelse return CodeGenError.CompilationError;
+                core.LLVMPositionBuilderAtEnd(self.builder, function_entry);
+
+                for (function_definition.statements) |stmt| {
+                    try self.generate_statement(stmt);
+                }
+
+                return try self.create_variable(.{
+                    .value = function,
+                    .type = function_type,
+                });
+            },
+            .FUNCTION_CALL_STATEMENT => |*fn_call| {
+                const r = try self.generate_function_call_statement(@ptrCast(fn_call));
+                return try self.create_variable(.{
+                    .value = r,
+                    .type = core.LLVMInt64Type(),
+                });
+            },
             .PRIMARY_EXPRESSION => |primary_expression| switch (primary_expression) {
-                .NUMBER => |n| core.LLVMConstInt(core.LLVMInt64Type(), @intCast(n.value), 0),
-                .IDENTIFIER => |i| core.LLVMBuildLoad2(self.builder, core.LLVMInt64Type(), self.symbol_table.get(i.name).?.value, "").?,
+                .NUMBER => |n| {
+                    const ptr = core.LLVMBuildAlloca(self.builder, core.LLVMInt64Type(), "") orelse return CodeGenError.CompilationError;
+                    _ = core.LLVMBuildStore(self.builder, core.LLVMConstInt(core.LLVMInt64Type(), @intCast(n.value), 0), ptr) orelse return CodeGenError.CompilationError;
+                    const variable = core.LLVMBuildLoad2(self.builder, core.LLVMInt64Type(), ptr, "") orelse return CodeGenError.CompilationError;
+                    return try self.create_variable(.{
+                        .value = variable,
+                        .type = core.LLVMInt64Type(),
+                    });
+                },
+                .IDENTIFIER => |i| self.symbol_table.get(i.name).?,
                 else => unreachable,
             },
-            .FUNCTION_CALL_STATEMENT => |*fn_call| try self.generate_function_call_statement(@ptrCast(fn_call)),
+            // .ADDITIVE_EXPRESSION => |exp| {
+            //     const lhs_value = self.get_expression_value(exp.lhs);
+            //     const rhs_value = self.get_expression_value(exp.rhs);
+            //
+            //     core.LLVMBuildAdd(self.builder, lhs_value, rhs_value);
+            // },
             else => unreachable,
         };
-
-        _ = core.LLVMBuildRet(self.builder, num_argument);
     }
 
-    pub fn create_entrypoint(self: *CodeGen) CodeGenError!void {
+    fn create_entrypoint(self: *CodeGen) CodeGenError!void {
         const start_function_type = core.LLVMFunctionType(core.LLVMInt8Type(), &[_]types.LLVMTypeRef{}, 0, 0) orelse return CodeGenError.CompilationError;
         const start_function = core.LLVMAddFunction(self.llvm_module, "_start", start_function_type) orelse return CodeGenError.CompilationError;
         const start_function_entry = core.LLVMAppendBasicBlock(start_function, "entrypoint") orelse return CodeGenError.CompilationError;
         core.LLVMPositionBuilderAtEnd(self.builder, start_function_entry);
 
-        const main_function_type = core.LLVMFunctionType(core.LLVMInt8Type(), &[_]types.LLVMTypeRef{}, 0, 0) orelse return CodeGenError.CompilationError;
-        const main_function = core.LLVMGetNamedFunction(self.llvm_module, "main") orelse return CodeGenError.CompilationError;
-        const main_function_return = core.LLVMBuildCall2(self.builder, main_function_type, main_function, &[_]types.LLVMTypeRef{}, 0, "main_call") orelse return CodeGenError.CompilationError;
+        const main_function = self.symbol_table.get("main") orelse return CodeGenError.CompilationError;
+        const main_function_return = core.LLVMBuildCall2(self.builder, main_function.type, main_function.value, &[_]types.LLVMTypeRef{}, 0, "main_call") orelse return CodeGenError.CompilationError;
 
         const exit_func_type = core.LLVMFunctionType(core.LLVMInt8Type(), @constCast(&[_]types.LLVMTypeRef{core.LLVMInt8Type()}), 1, 0);
         const exit_func = core.LLVMAddFunction(self.llvm_module, "exit", exit_func_type);
diff --git a/src/main.zig b/src/main.zig
index 5adec59..b1d8cb3 100644
--- a/src/main.zig
+++ b/src/main.zig
@@ -23,7 +23,9 @@ pub fn main() !void {
 
     const source_evaluator = try evaluator.Evaluator.init(arena.allocator());
     const source_codegen = try codegen.CodeGen.init(arena.allocator());
-    defer source_codegen.deinit() catch {};
+    defer source_codegen.deinit() catch |err| {
+        std.debug.print("ERROR GENERATING CODE {any}\n", .{err});
+    };
 
     if (std.mem.eql(u8, path, "-i")) {
         while (true) {