diff options
| -rw-r--r-- | src/codegen.zig | 121 |
1 files changed, 69 insertions, 52 deletions
diff --git a/src/codegen.zig b/src/codegen.zig index 3b6ab21..3969df5 100644 --- a/src/codegen.zig +++ b/src/codegen.zig @@ -134,10 +134,9 @@ pub const CodeGen = struct { const alloca = llvm.LLVMBuildAlloca(self.builder, llvm.LLVMInt64Type(), try std.fmt.allocPrintZ(self.arena, "{s}", .{identifier.name})); //TODO: Correct type try self.environment.add_variable(identifier.name, try self.create_variable(.{ .value = alloca, - .type = llvm.LLVMVoidType(), // This gets set to the correct type during the expression type resolution. ALTERNATIVE: Pass the alloca .stack_level = null, .node = statement, - .node_type = null, + .node_type = null, // This gets set to the correct type during the expression type resolution. ALTERNATIVE: Pass the alloca })); } @@ -145,11 +144,10 @@ pub const CodeGen = struct { if (assignment_statement.is_dereference) { const ptr = self.environment.get_variable(identifier.name) orelse unreachable; undereferenced_variable = ptr; - const x = llvm.LLVMBuildLoad2(self.builder, ptr.type, ptr.value, "") orelse return CodeGenError.CompilationError; + const x = llvm.LLVMBuildLoad2(self.builder, try self.get_llvm_type(ptr.node_type.?), ptr.value, "") orelse return CodeGenError.CompilationError; std.debug.assert(ptr.node_type.?.TYPE == .POINTER_TYPE); try self.environment.add_variable(identifier.name, try self.create_variable(.{ .value = x, - .type = ptr.type, .stack_level = null, .node = statement, .node_type = ptr.node_type.?.TYPE.POINTER_TYPE.type, @@ -184,7 +182,7 @@ pub const CodeGen = struct { std.debug.assert(primary_expression == .IDENTIFIER); function = self.environment.get_variable(primary_expression.IDENTIFIER.name) orelse return CodeGenError.CompilationError; if (llvm.LLVMGetValueKind(function.value) != llvm.LLVMFunctionValueKind) { - function.value = llvm.LLVMBuildLoad2(self.builder, llvm.LLVMPointerType(function.type, 0), function.value, ""); + function.value = llvm.LLVMBuildLoad2(self.builder, llvm.LLVMPointerType(try self.get_llvm_type(function.node_type.?), 0), function.value, ""); node = function.node.?; } }, @@ -201,7 +199,7 @@ pub const CodeGen = struct { try arguments.append(arg.value); } - const res = llvm.LLVMBuildCall2(self.builder, function.type, function.value, @ptrCast(arguments.items), @intCast(arguments.items.len), "") orelse return CodeGenError.CompilationError; + const res = llvm.LLVMBuildCall2(self.builder, try self.get_llvm_type(function.node_type.?), function.value, @ptrCast(arguments.items), @intCast(arguments.items.len), "") orelse return CodeGenError.CompilationError; const function_return_type = switch (function.node.?.*) { .FUNCTION_DEFINITION => |x| x.return_type, @@ -210,10 +208,7 @@ pub const CodeGen = struct { else => unreachable, }; - const typ = try self.get_llvm_type(function_return_type); - return self.create_variable(.{ - .type = typ, .value = res, .stack_level = null, .node = node, @@ -290,20 +285,22 @@ pub const CodeGen = struct { // Functions should be declared "globally" const builder_pos = llvm.LLVMGetInsertBlock(self.builder); - var paramtypes = std.ArrayList(llvm.LLVMTypeRef).init(self.arena); + var llvm_param_types = std.ArrayList(llvm.LLVMTypeRef).init(self.arena); + var param_types = std.ArrayList(*parser.Node).init(self.arena); for (function_definition.parameters) |param| { std.debug.assert(param.PRIMARY_EXPRESSION == .IDENTIFIER); var param_type = try self.get_llvm_type(param.PRIMARY_EXPRESSION.IDENTIFIER.type.?); if (param.PRIMARY_EXPRESSION.IDENTIFIER.type.?.TYPE == .FUNCTION_TYPE) { param_type = llvm.LLVMPointerType(param_type.?, 0); } - try paramtypes.append(param_type); + try llvm_param_types.append(param_type); + try param_types.append(param.PRIMARY_EXPRESSION.IDENTIFIER.type.?); } var return_type = try self.get_llvm_type(function_definition.return_type); if (function_definition.return_type.TYPE == .FUNCTION_TYPE) { return_type = llvm.LLVMPointerType(return_type, 0); } - const function_type = llvm.LLVMFunctionType(return_type, paramtypes.items.ptr, @intCast(paramtypes.items.len), 0) orelse return CodeGenError.CompilationError; + const function_type = llvm.LLVMFunctionType(return_type, llvm_param_types.items.ptr, @intCast(llvm_param_types.items.len), 0) orelse return CodeGenError.CompilationError; const function = llvm.LLVMAddFunction(self.llvm_module, try std.fmt.allocPrintZ(self.arena, "{s}", .{name orelse "unnamed_func"}), function_type) orelse return CodeGenError.CompilationError; const function_entry = llvm.LLVMAppendBasicBlock(function, "entrypoint") orelse return CodeGenError.CompilationError; llvm.LLVMPositionBuilderAtEnd(self.builder, function_entry); @@ -313,15 +310,23 @@ pub const CodeGen = struct { var ptr: ?*Variable = null; + const node_type = try self.create_node(.{ + .TYPE = .{ + .FUNCTION_TYPE = .{ + .parameters = param_types.items, + .return_type = function_definition.return_type, + }, + }, + }); + // Needed for recursive functions if (name != null) { ptr = self.environment.get_variable(name.?); try self.environment.add_variable(name.?, try self.create_variable(.{ .value = function, - .type = function_type, .stack_level = null, .node = expression, - .node_type = null, + .node_type = node_type, })); } @@ -345,7 +350,6 @@ pub const CodeGen = struct { try self.environment.add_variable(param_node.PRIMARY_EXPRESSION.IDENTIFIER.name, try self.create_variable(.{ .value = alloca, - .type = param_type, .stack_level = null, .node = param_node, .node_type = param_node.PRIMARY_EXPRESSION.IDENTIFIER.type, @@ -363,15 +367,14 @@ pub const CodeGen = struct { if (name == null or self.environment.scope_stack.items.len == 2) { return try self.create_variable(.{ .value = function, - .type = function_type, .stack_level = null, .node = expression, - .node_type = null, + .node_type = node_type, }); } _ = llvm.LLVMBuildStore(self.builder, function, ptr.?.value) orelse return CodeGenError.CompilationError; - ptr.?.type = function_type; + ptr.?.node_type = node_type; ptr.?.node = expression; return ptr.?; @@ -381,7 +384,6 @@ pub const CodeGen = struct { const ptr = self.environment.get_variable(name.?) orelse unreachable; const result = try self.generate_function_call_statement(@ptrCast(fn_call)); _ = llvm.LLVMBuildStore(self.builder, result.value, ptr.value) orelse return CodeGenError.CompilationError; - ptr.type = result.type; ptr.node = result.node; ptr.node_type = result.node_type; return ptr; @@ -391,7 +393,7 @@ pub const CodeGen = struct { }, .PRIMARY_EXPRESSION => |primary_expression| switch (primary_expression) { .NUMBER => |n| { - return try self.generate_literal(llvm.LLVMConstInt(llvm.LLVMInt64Type(), @intCast(n.value), 0), llvm.LLVMInt64Type(), name, expression, try self.create_node(.{ + return try self.generate_literal(llvm.LLVMConstInt(llvm.LLVMInt64Type(), @intCast(n.value), 0), name, expression, try self.create_node(.{ .TYPE = .{ .SIMPLE_TYPE = .{ .name = "i64", @@ -405,16 +407,16 @@ pub const CodeGen = struct { true => 1, }; - return try self.generate_literal(llvm.LLVMConstInt(llvm.LLVMInt1Type(), @intCast(int_value), 0), llvm.LLVMInt1Type(), name, expression, try self.create_node(.{ + return try self.generate_literal(llvm.LLVMConstInt(llvm.LLVMInt1Type(), @intCast(int_value), 0), name, expression, try self.create_node(.{ .TYPE = .{ .SIMPLE_TYPE = .{ - .name = "i1", + .name = "bool", }, }, })); }, .CHAR => |c| { - return try self.generate_literal(llvm.LLVMConstInt(llvm.LLVMInt8Type(), @intCast(c.value), 0), llvm.LLVMInt8Type(), name, expression, try self.create_node(.{ + return try self.generate_literal(llvm.LLVMConstInt(llvm.LLVMInt8Type(), @intCast(c.value), 0), name, expression, try self.create_node(.{ .TYPE = .{ .SIMPLE_TYPE = .{ .name = "i8", @@ -427,7 +429,6 @@ pub const CodeGen = struct { return self.create_variable( .{ .value = x, - .type = llvm.LLVMPointerType(llvm.LLVMInt8Type(), 0), .stack_level = null, .node = expression, .node_type = try self.create_node(.{ @@ -446,8 +447,8 @@ pub const CodeGen = struct { }, .IDENTIFIER => |i| { const variable = self.environment.get_variable(i.name).?; - var param_type = variable.type; - if (llvm.LLVMGetTypeKind(param_type.?) == llvm.LLVMFunctionTypeKind) { + var param_type = try self.get_llvm_type(variable.node_type.?); + if (variable.node_type.?.TYPE == .FUNCTION_TYPE) { param_type = llvm.LLVMPointerType(param_type.?, 0); } @@ -459,7 +460,7 @@ pub const CodeGen = struct { loaded = llvm.LLVMBuildLoad2(self.builder, param_type, variable.value, ""); } - return self.generate_literal(loaded, variable.type, name, expression, variable.node_type); + return self.generate_literal(loaded, name, expression, variable.node_type.?); }, }, .ADDITIVE_EXPRESSION => |exp| { @@ -472,7 +473,7 @@ pub const CodeGen = struct { } } }); if (exp.addition) { - if (llvm.LLVMGetTypeKind(lhs_value.type.?) == llvm.LLVMPointerTypeKind) { + if (lhs_value.node_type.?.TYPE == .POINTER_TYPE) { std.debug.print("DEBUG: {any}\n", .{expression}); result = llvm.LLVMBuildGEP2(self.builder, try self.get_llvm_type(lhs_value.node_type.?.TYPE.POINTER_TYPE.type), lhs_value.value, @constCast(&[_]llvm.LLVMValueRef{rhs_value.value}), 1, ""); node_type = lhs_value.node_type.?; @@ -483,7 +484,7 @@ pub const CodeGen = struct { result = llvm.LLVMBuildSub(self.builder, lhs_value.value, rhs_value.value, "") orelse return CodeGenError.CompilationError; } - return self.generate_literal(result, llvm.LLVMInt64Type(), name, expression, node_type); + return self.generate_literal(result, name, expression, node_type); }, .MULTIPLICATIVE_EXPRESSION => |exp| { const lhs_value = try self.generate_expression_value(exp.lhs, null); @@ -502,34 +503,44 @@ pub const CodeGen = struct { }, } - return self.generate_literal(result, llvm.LLVMInt64Type(), name, expression, lhs_value.node_type); + return self.generate_literal(result, name, expression, lhs_value.node_type.?); }, .UNARY_EXPRESSION => |exp| { const k = try self.generate_expression_value(exp.expression, null); var r: llvm.LLVMValueRef = undefined; - var t: llvm.LLVMTypeRef = undefined; - var uwu: *parser.Node = k.node_type.?; + var typ: *parser.Node = k.node_type.?; switch (exp.typ) { .NOT => { - std.debug.assert(k.type == llvm.LLVMInt1Type()); + std.debug.assert(std.mem.eql(u8, k.node_type.?.TYPE.SIMPLE_TYPE.name, "bool")); //TODO r = llvm.LLVMBuildICmp(self.builder, llvm.LLVMIntEQ, k.value, llvm.LLVMConstInt(llvm.LLVMInt1Type(), 0, 0), ""); - t = llvm.LLVMInt1Type(); + typ = try self.create_node(.{ + .TYPE = .{ + .SIMPLE_TYPE = .{ + .name = "bool", + }, + }, + }); }, .MINUS => { r = llvm.LLVMBuildNeg(self.builder, k.value, ""); - t = llvm.LLVMInt64Type(); + typ = try self.create_node(.{ + .TYPE = .{ + .SIMPLE_TYPE = .{ + .name = "i64", + }, + }, + }); }, .STAR => { - r = llvm.LLVMBuildLoad2(self.builder, k.type, k.value, ""); - std.debug.print("TEST: {any}\n", .{k.node_type}); std.debug.assert(k.node_type.?.TYPE == .POINTER_TYPE); - t = try self.get_llvm_type(k.node_type.?.TYPE.POINTER_TYPE.type); - uwu = k.node_type.?.TYPE.POINTER_TYPE.type; + typ = k.node_type.?.TYPE.POINTER_TYPE.type; + r = llvm.LLVMBuildLoad2(self.builder, try self.get_llvm_type(typ), k.value, ""); + std.debug.print("TESTXXX: {any}\n", .{k.node_type.?.TYPE.POINTER_TYPE.type.TYPE}); }, } - return self.generate_literal(r, t, name, expression, uwu); //TODO: Why do we need the llvm type at all + return self.generate_literal(r, name, expression, typ); }, .EQUALITY_EXPRESSION => |exp| { const lhs_value = try self.generate_expression_value(exp.lhs, null); @@ -542,7 +553,13 @@ pub const CodeGen = struct { }; const cmp = llvm.LLVMBuildICmp(self.builder, op, lhs_value.value, rhs_value.value, ""); - return self.generate_literal(cmp, llvm.LLVMInt1Type(), name, expression, lhs_value.node_type); + return self.generate_literal(cmp, name, expression, try self.create_node(.{ + .TYPE = .{ + .SIMPLE_TYPE = .{ + .name = "bool", + }, + }, + })); }, .TYPE => |typ| { std.debug.assert(typ == .FUNCTION_TYPE); @@ -553,7 +570,6 @@ pub const CodeGen = struct { if (self.environment.scope_stack.items.len == 1) { return try self.create_variable(.{ .value = function, - .type = function_type, .stack_level = null, .node = expression, .node_type = expression, @@ -562,7 +578,6 @@ pub const CodeGen = struct { const ptr = self.environment.get_variable(name.?); _ = llvm.LLVMBuildStore(self.builder, function, ptr.?.value) orelse return CodeGenError.CompilationError; - ptr.?.type = function_type; ptr.?.node = expression; ptr.?.node_type = expression; @@ -572,12 +587,11 @@ pub const CodeGen = struct { }; } - fn generate_literal(self: *CodeGen, literal_val: llvm.LLVMValueRef, literal_type: llvm.LLVMTypeRef, name: ?[]const u8, node: *parser.Node, node_type: ?*parser.Node) !*Variable { + fn generate_literal(self: *CodeGen, literal_val: llvm.LLVMValueRef, name: ?[]const u8, node: *parser.Node, node_type: *parser.Node) !*Variable { if (name != null) { if (self.environment.scope_stack.items.len == 1) { const ptr = try self.create_variable(.{ - .value = llvm.LLVMAddGlobal(self.llvm_module, literal_type, try std.fmt.allocPrintZ(self.arena, "{s}", .{name.?})), - .type = literal_type, + .value = llvm.LLVMAddGlobal(self.llvm_module, try self.get_llvm_type(node_type), try std.fmt.allocPrintZ(self.arena, "{s}", .{name.?})), .stack_level = null, .node = node, .node_type = node_type, @@ -587,7 +601,6 @@ pub const CodeGen = struct { } const ptr = self.environment.get_variable(name.?) orelse unreachable; _ = llvm.LLVMBuildStore(self.builder, literal_val, ptr.value) orelse return CodeGenError.CompilationError; - ptr.type = literal_type; ptr.node = node; ptr.node_type = node_type; return ptr; @@ -595,7 +608,6 @@ pub const CodeGen = struct { return try self.create_variable(.{ .value = literal_val, - .type = literal_type, .stack_level = null, .node = node, .node_type = node_type, @@ -614,9 +626,11 @@ pub const CodeGen = struct { if (std.mem.eql(u8, t.name, "void")) return llvm.LLVMVoidType(); unreachable; }, - // TODO: Properly handle this vv .FUNCTION_TYPE => |t| { - const return_type = try self.get_llvm_type(t.return_type); + var return_type = try self.get_llvm_type(t.return_type); + if (t.return_type.TYPE == .FUNCTION_TYPE) { + return_type = llvm.LLVMPointerType(return_type, 0); + } var paramtypes = std.ArrayList(llvm.LLVMTypeRef).init(self.arena); var is_varargs: i8 = 0; for (t.parameters) |param| { @@ -624,7 +638,11 @@ pub const CodeGen = struct { is_varargs = 1; continue; } - try paramtypes.append(try self.get_llvm_type(param)); + var typ = try self.get_llvm_type(param); + if (param.TYPE == .FUNCTION_TYPE) { + typ = llvm.LLVMPointerType(typ, 0); + } + try paramtypes.append(typ); } const function_type = llvm.LLVMFunctionType(return_type, paramtypes.items.ptr, @intCast(paramtypes.items.len), is_varargs) orelse unreachable; return function_type; @@ -663,7 +681,6 @@ pub const CodeGen = struct { }; const Variable = struct { - type: llvm.LLVMTypeRef, value: llvm.LLVMValueRef, node: ?*parser.Node, node_type: ?*parser.Node, |