summary refs log tree commit diff
path: root/src/codegen.zig
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/codegen.zig83
1 files changed, 75 insertions, 8 deletions
diff --git a/src/codegen.zig b/src/codegen.zig
index 46c2dd1..f5ef608 100644
--- a/src/codegen.zig
+++ b/src/codegen.zig
@@ -19,7 +19,7 @@ const Variable = struct {
 pub const CodeGen = struct {
     llvm_module: types.LLVMModuleRef,
     builder: types.LLVMBuilderRef,
-    symbol_table: std.StringHashMap(*Variable), //TODO: Scopes, also add functions here
+    environment: *Environment,
 
     arena: std.mem.Allocator,
 
@@ -36,7 +36,7 @@ pub const CodeGen = struct {
         self.* = .{
             .llvm_module = module,
             .builder = builder,
-            .symbol_table = std.StringHashMap(*Variable).init(arena),
+            .environment = try Environment.init(arena),
 
             .arena = arena,
         };
@@ -63,7 +63,7 @@ pub const CodeGen = struct {
         _ = core.LLVMBuildCall2(self.builder, printf_function_type, printf_function, arguments, 2, "") orelse return CodeGenError.CompilationError;
         _ = core.LLVMBuildRetVoid(self.builder);
 
-        try self.symbol_table.put("print", try self.create_variable(.{
+        try self.environment.add_variable("print", try self.create_variable(.{
             .value = print_function,
             .type = print_function_type,
         }));
@@ -139,11 +139,11 @@ pub const CodeGen = struct {
         const assignment_statement = statement.ASSIGNMENT_STATEMENT;
 
         if (!assignment_statement.is_declaration) {
-            std.debug.assert(self.symbol_table.contains(assignment_statement.name));
+            std.debug.assert(self.environment.contains_variable(assignment_statement.name));
         }
 
         const variable = try self.generate_expression_value(assignment_statement.expression);
-        try self.symbol_table.put(assignment_statement.name, variable);
+        try self.environment.add_variable(assignment_statement.name, variable);
     }
 
     fn generate_function_call_statement(self: *CodeGen, statement: *parser.Node) CodeGenError!types.LLVMValueRef {
@@ -155,7 +155,7 @@ pub const CodeGen = struct {
 
         std.debug.assert(primary_expression == .IDENTIFIER);
         const ident = primary_expression.IDENTIFIER;
-        const xd = self.symbol_table.get(ident.name) orelse return CodeGenError.CompilationError;
+        const xd = self.environment.get_variable(ident.name) orelse return CodeGenError.CompilationError;
 
         var arguments = std.ArrayList(types.LLVMValueRef).init(self.arena);
 
@@ -178,6 +178,13 @@ pub const CodeGen = struct {
     fn generate_expression_value(self: *CodeGen, expression: *parser.Node) !*Variable {
         return switch (expression.*) {
             .FUNCTION_DEFINITION => |function_definition| {
+                try self.environment.create_scope();
+                defer self.environment.drop_scope();
+
+                // Functions should be declared "globally"
+                const builder_pos = core.LLVMGetInsertBlock(self.builder);
+                defer core.LLVMPositionBuilderAtEnd(self.builder, builder_pos);
+
                 const function_type = core.LLVMFunctionType(core.LLVMInt64Type(), &[_]types.LLVMTypeRef{}, 0, 0) orelse return CodeGenError.CompilationError;
                 const function = core.LLVMAddFunction(self.llvm_module, "", function_type) orelse return CodeGenError.CompilationError;
                 const function_entry = core.LLVMAppendBasicBlock(function, "entrypoint") orelse return CodeGenError.CompilationError;
@@ -209,7 +216,7 @@ pub const CodeGen = struct {
                         .type = core.LLVMInt64Type(),
                     });
                 },
-                .IDENTIFIER => |i| self.symbol_table.get(i.name).?,
+                .IDENTIFIER => |i| self.environment.get_variable(i.name).?,
                 else => unreachable,
             },
             .ADDITIVE_EXPRESSION => |exp| {
@@ -232,7 +239,7 @@ pub const CodeGen = struct {
         const start_function_entry = core.LLVMAppendBasicBlock(start_function, "entrypoint") orelse return CodeGenError.CompilationError;
         core.LLVMPositionBuilderAtEnd(self.builder, start_function_entry);
 
-        const main_function = self.symbol_table.get("main") orelse return CodeGenError.CompilationError;
+        const main_function = self.environment.get_variable("main") orelse return CodeGenError.CompilationError;
         const main_function_return = core.LLVMBuildCall2(self.builder, main_function.type, main_function.value, &[_]types.LLVMTypeRef{}, 0, "main_call") orelse return CodeGenError.CompilationError;
 
         const exit_func_type = core.LLVMFunctionType(core.LLVMVoidType(), @constCast(&[_]types.LLVMTypeRef{core.LLVMInt8Type()}), 1, 0);
@@ -247,3 +254,63 @@ pub const CodeGen = struct {
         return variable;
     }
 };
+
+const Scope = struct {
+    variables: std.StringHashMap(*Variable),
+};
+
+const Environment = struct {
+    scope_stack: std.ArrayList(*Scope),
+
+    arena: std.mem.Allocator,
+
+    fn init(arena_allocator: std.mem.Allocator) !*Environment {
+        const self = try arena_allocator.create(Environment);
+
+        self.* = .{
+            .scope_stack = std.ArrayList(*Scope).init(arena_allocator),
+            .arena = arena_allocator,
+        };
+
+        // Create global scope
+        try self.create_scope();
+
+        return self;
+    }
+
+    fn create_scope(self: *Environment) !void {
+        const scope = try self.arena.create(Scope);
+        scope.* = .{
+            .variables = std.StringHashMap(*Variable).init(self.arena),
+        };
+        try self.scope_stack.append(scope);
+    }
+
+    fn drop_scope(self: *Environment) void {
+        _ = self.scope_stack.pop();
+    }
+
+    fn add_variable(self: *Environment, name: []const u8, variable: *Variable) !void {
+        try self.scope_stack.getLast().variables.put(name, variable);
+    }
+
+    fn get_variable(self: *Environment, name: []const u8) ?*Variable {
+        var i = self.scope_stack.items.len;
+        while (i > 0) {
+            i -= 1;
+            const scope = self.scope_stack.items[i];
+            if (scope.variables.get(name)) |v| return v;
+        }
+        return null;
+    }
+
+    fn contains_variable(self: *Environment, name: []const u8) bool {
+        var i = self.scope_stack.items.len;
+        while (i > 0) {
+            i -= 1;
+            const scope = self.scope_stack.items[i];
+            if (scope.variables.contains(name)) return true;
+        }
+        return false;
+    }
+};