about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
authorBaitinq <[email protected]>2025-07-16 00:42:32 +0200
committerBaitinq <[email protected]>2025-07-16 00:54:23 +0200
commitbc798b9eeb9b5a76a36c28c1d46f7f7ea9c19c40 (patch)
tree2ea2a99fb0f7c0a790173fac825ac78b84a5193b /src
parentFeature: Add support for and and or operators (diff)
downloadpry-lang-bc798b9eeb9b5a76a36c28c1d46f7f7ea9c19c40.tar.gz
pry-lang-bc798b9eeb9b5a76a36c28c1d46f7f7ea9c19c40.tar.bz2
pry-lang-bc798b9eeb9b5a76a36c28c1d46f7f7ea9c19c40.zip
Feature: Implement and/or short circuiting
Diffstat (limited to 'src')
-rw-r--r--src/codegen.pry88
-rw-r--r--src/llvm.pry2
2 files changed, 77 insertions, 13 deletions
diff --git a/src/codegen.pry b/src/codegen.pry
index 4c257b8..ab40c27 100644
--- a/src/codegen.pry
+++ b/src/codegen.pry
@@ -724,23 +724,87 @@ let codegen_generate_expression_value = (c: *codegen, expression: *Node, name: *
 	
 	if ((*expression).type == NODE_LOGICAL_EXPRESSION) {
 		let exp = (*(cast(*NODE_LOGICAL_EXPRESSION_DATA, (*expression).data)));
-		let lhs_value = codegen_generate_expression_value(c, exp.lhs, cast(*i8, null));
-		assert(lhs_value != cast(*Variable, null));
-		let rhs_value = codegen_generate_expression_value(c, exp.rhs, cast(*i8, null));
-		assert(rhs_value != cast(*Variable, null));
+
+		let current_block = cast(*LLVMBasicBlockRef, arena_alloc((*c).arena, sizeof(LLVMBasicBlockRef)));
+		*current_block = LLVMGetInsertBlock((*c).builder);
+		
+		let node_type = Node{};
+		node_type.type = NODE_TYPE_SIMPLE_TYPE;
+		let d = cast(*NODE_TYPE_SIMPLE_TYPE_DATA, arena_alloc((*c).arena, sizeof(NODE_TYPE_SIMPLE_TYPE_DATA)));
+		(*d).name = "bool";
+		(*d).underlying_type = cast(*Node, null);
+		node_type.data = cast(*void, d);
 		
-		assert(compare_types(c, (*lhs_value).node_type, (*rhs_value).node_type, false));
 
 		let result = cast(LLVMValueRef, null);
 		if exp.an {
-			result = LLVMBuildAnd((*c).builder, (*lhs_value).value, (*rhs_value).value, "");
+			let rhs_block = LLVMAppendBasicBlock((*c).current_function, "and_rhs");
+			let merge_block = LLVMAppendBasicBlock((*c).current_function, "and_merge");
+			let lhs_value = codegen_generate_expression_value(c, exp.lhs, cast(*i8, null));
+			assert(lhs_value != cast(*Variable, null));
+
+			LLVMBuildCondBr((*c).builder, (*lhs_value).value, rhs_block, merge_block);
+			LLVMPositionBuilderAtEnd((*c).builder, rhs_block);
+
+			let rhs_value = codegen_generate_expression_value(c, exp.rhs, cast(*i8, null));
+			assert(rhs_value != cast(*Variable, null));
+
+			assert(compare_types(c, (*lhs_value).node_type, (*rhs_value).node_type, false));
+
+			let rhs_end_block = cast(*LLVMBasicBlockRef, arena_alloc((*c).arena, sizeof(LLVMBasicBlockRef)));
+			*rhs_end_block = LLVMGetInsertBlock((*c).builder);
+			LLVMBuildBr((*c).builder, merge_block);	
+			LLVMPositionBuilderAtEnd((*c).builder, merge_block);
+
+			let phi = LLVMBuildPhi((*c).builder, LLVMInt1Type(), "and_result");
+
+			let fals_val = cast(*LLVMValueRef, arena_alloc((*c).arena, sizeof(LLVMValueRef)));
+			let rhs_val = cast(*LLVMValueRef, arena_alloc((*c).arena, sizeof(LLVMValueRef)));
+			*fals_val = LLVMConstInt(LLVMInt1Type(), 0, 0);
+			*rhs_val = ((*rhs_value).value);
+			LLVMAddIncoming(phi, fals_val, current_block, 1);
+			LLVMAddIncoming(phi, rhs_val, rhs_end_block, 1);
+
+			return codegen_generate_literal(c, phi, name, expression, create_node(c, node_type));
 		};
 		if !exp.an {
-			result = LLVMBuildOr((*c).builder, (*lhs_value).value, (*rhs_value).value, "");
+			let rhs_block = LLVMAppendBasicBlock((*c).current_function, "or_rhs");
+			let merge_block = LLVMAppendBasicBlock((*c).current_function, "or_merge");
+
+			let lhs_value = codegen_generate_expression_value(c, exp.lhs, cast(*i8, null));
+			assert(lhs_value != cast(*Variable, null));
+
+			LLVMBuildCondBr((*c).builder, (*lhs_value).value, merge_block, rhs_block);
+
+			LLVMPositionBuilderAtEnd((*c).builder, rhs_block);
+			let rhs_value = codegen_generate_expression_value(c, exp.rhs, cast(*i8, null));
+			assert(rhs_value != cast(*Variable, null));
+
+			assert(compare_types(c, (*lhs_value).node_type, (*rhs_value).node_type, false));
+
+			let rhs_end_block = cast(*LLVMBasicBlockRef, arena_alloc((*c).arena, sizeof(LLVMBasicBlockRef)));
+			*rhs_end_block = LLVMGetInsertBlock((*c).builder);
+			LLVMBuildBr((*c).builder, merge_block);
+
+			LLVMPositionBuilderAtEnd((*c).builder, merge_block);
+
+			let phi = LLVMBuildPhi((*c).builder, LLVMInt1Type(), "or_result");
+
+			let tru_val = cast(*LLVMValueRef, arena_alloc((*c).arena, sizeof(LLVMValueRef)));
+			let rhs_val = cast(*LLVMValueRef, arena_alloc((*c).arena, sizeof(LLVMValueRef)));
+
+			*tru_val = LLVMConstInt(LLVMInt1Type(), 1, 0);
+			*rhs_val = (*rhs_value).value;
+
+			LLVMAddIncoming(phi, tru_val, current_block, 1);
+			LLVMAddIncoming(phi, rhs_val, rhs_end_block, 1);
+
+			return codegen_generate_literal(c, phi, name, expression, create_node(c, node_type));
 		};
-		assert(result != cast(LLVMValueRef, null));
 
-		return codegen_generate_literal(c, result, name, expression, (*lhs_value).node_type);
+		assert(false);
+
+		return cast(*Variable, null);
 	};
 	
 	if ((*expression).type == NODE_EQUALITY_EXPRESSION) {
@@ -1299,10 +1363,8 @@ let codegen_generate_if_statement = (c: *codegen, statement: *NODE_IF_STATEMENT_
 	if last_instr == cast(LLVMValueRef, null) {
 		LLVMBuildBr((*c).builder, merge_block);
 	};
-	if last_instr != cast(LLVMValueRef, null) {
-		if LLVMIsATerminatorInst(last_instr) == cast(LLVMValueRef, null) {
-			LLVMBuildBr((*c).builder, merge_block);
-		};
+	if last_instr != cast(LLVMValueRef, null) and LLVMIsATerminatorInst(last_instr) == cast(LLVMValueRef, null) {
+		LLVMBuildBr((*c).builder, merge_block);
 	};
 	LLVMPositionBuilderAtEnd((*c).builder, current_block);
         LLVMBuildCondBr((*c).builder, (*condition_value).value, then_block, merge_block);
diff --git a/src/llvm.pry b/src/llvm.pry
index 6608372..cdeac4f 100644
--- a/src/llvm.pry
+++ b/src/llvm.pry
@@ -331,6 +331,8 @@ extern LLVMBuildSDiv = (LLVMBuilderRef, LLVMValueRef, LLVMValueRef, *i8) => LLVM
 extern LLVMBuildSRem = (LLVMBuilderRef, LLVMValueRef, LLVMValueRef, *i8) => LLVMValueRef;
 extern LLVMBuildAnd = (LLVMBuilderRef, LLVMValueRef, LLVMValueRef, *i8) => LLVMValueRef;
 extern LLVMBuildOr = (LLVMBuilderRef, LLVMValueRef, LLVMValueRef, *i8) => LLVMValueRef;
+extern LLVMBuildPhi = (LLVMBuilderRef, LLVMTypeRef, *i8) => LLVMValueRef;
+extern LLVMAddIncoming = (LLVMValueRef, *LLVMValueRef, *LLVMBasicBlockRef, i64) => void;
 
 extern LLVMBuildGEP2 = (LLVMBuilderRef, LLVMTypeRef, LLVMValueRef, *LLVMValueRef, i64, *i8) => LLVMValueRef;
 extern LLVMAddGlobal = (LLVMModuleRef, LLVMTypeRef, *i8) => LLVMValueRef;