From fa6d9cdf57db4244a331035d35d1b962630e3ae1 Mon Sep 17 00:00:00 2001 From: Baitinq Date: Sat, 15 Feb 2025 10:56:08 +0100 Subject: Feature: Introduce initial support for function return types --- src/codegen.zig | 8 +++++++- src/parser.zig | 8 +++++++- 2 files changed, 14 insertions(+), 2 deletions(-) (limited to 'src') diff --git a/src/codegen.zig b/src/codegen.zig index 3884923..ed1890d 100644 --- a/src/codegen.zig +++ b/src/codegen.zig @@ -247,7 +247,8 @@ pub const CodeGen = struct { std.debug.assert(param.PRIMARY_EXPRESSION == .IDENTIFIER); try paramtypes.append(core.LLVMInt64Type()); } - const function_type = core.LLVMFunctionType(core.LLVMInt64Type(), paramtypes.items.ptr, @intCast(paramtypes.items.len), 0) orelse return CodeGenError.CompilationError; + const return_type = get_llvm_type(function_definition.return_type); + const function_type = core.LLVMFunctionType(return_type, paramtypes.items.ptr, @intCast(paramtypes.items.len), 0) orelse return CodeGenError.CompilationError; const function = core.LLVMAddFunction(self.llvm_module, try std.fmt.allocPrintZ(self.arena, "{s}", .{name orelse "unnamed_func"}), function_type) orelse return CodeGenError.CompilationError; const function_entry = core.LLVMAppendBasicBlock(function, "entrypoint") orelse return CodeGenError.CompilationError; core.LLVMPositionBuilderAtEnd(self.builder, function_entry); @@ -455,6 +456,11 @@ pub const CodeGen = struct { }); } + fn get_llvm_type(type_name: []const u8) types.LLVMTypeRef { + if (std.mem.eql(u8, type_name, "i64")) return core.LLVMInt64Type(); + unreachable; + } + fn create_print_function(self: *CodeGen) !void { const print_function_type = core.LLVMFunctionType(core.LLVMVoidType(), @constCast(&[_]types.LLVMTypeRef{core.LLVMInt64Type()}), 1, 0); const print_function = core.LLVMAddFunction(self.llvm_module, "print", print_function_type); diff --git a/src/parser.zig b/src/parser.zig index c5eb623..7d71de3 100644 --- a/src/parser.zig +++ b/src/parser.zig @@ -62,6 +62,7 @@ pub const Node = union(enum) { FUNCTION_DEFINITION: struct { statements: []*Node, parameters: []*Node, + return_type: []const u8, }, RETURN_STATEMENT: struct { expression: *Node, @@ -376,7 +377,7 @@ pub const Parser = struct { }; } - // FunctionDefinition ::= LPAREN FunctionParamters? RPAREN ARROW LBRACE Statement* ReturnStatement RBRACE + // FunctionDefinition ::= LPAREN FunctionParameters? RPAREN ARROW IDENTIFIER LBRACE Statement* ReturnStatement SEMICOLON RBRACE fn parse_function_definition(self: *Parser) ParserError!*Node { errdefer if (!self.try_context) std.debug.print("Error parsing function definition {any}\n", .{self.peek_token()}); @@ -387,6 +388,10 @@ pub const Parser = struct { _ = try self.parse_token(tokenizer.TokenType.RPAREN); _ = try self.parse_token(tokenizer.TokenType.ARROW); + + const type_expr = try self.parse_primary_expression(); + if (type_expr.PRIMARY_EXPRESSION != .IDENTIFIER) return ParserError.ParsingError; + _ = try self.parse_token(tokenizer.TokenType.LBRACE); var nodes = std.ArrayList(*Node).init(self.arena); @@ -401,6 +406,7 @@ pub const Parser = struct { return self.create_node(.{ .FUNCTION_DEFINITION = .{ .statements = nodes.items, .parameters = parameters, + .return_type = type_expr.PRIMARY_EXPRESSION.IDENTIFIER.name, } }); } -- cgit 1.4.1