Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for ternary short-circuiting #707

Merged
merged 1 commit into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/docs/language/operator-overloading.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ title: Operator Overloading
---

Spice allows overloading operators for [custom struct types](structs.md).
Currently, this works for the operators `+`, `-`, `*`, `/`, `==`, `!=`, `<<`, `>>`, `+=`, `-=`, `*=`, `/=`,
Currently, this works for the operators `+`, `-`, `*`, `/`, `==`, `!=`, `<<`, `>>`, `+=`, `-=`, `*=`, `/=`, `[]`,
`++` (postfix) and `--` (postfix).
In the future, more operators will be supported for overloading.

Expand Down
44 changes: 34 additions & 10 deletions src/irgenerator/GenExpressions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,19 +95,43 @@ std::any IRGenerator::visitTernaryExpr(const TernaryExprNode *node) {
// It is a ternary
// Retrieve the condition value
llvm::Value *condValue = resolveValue(node->condition);

// Get the values of true and false
llvm::Value *trueValue;
llvm::Value *falseValue;
if (node->isShortened) {
trueValue = condValue;
falseValue = resolveValue(node->falseExpr);
const LogicalOrExprNode *trueNode = node->isShortened ? node->condition : node->trueExpr;
const LogicalOrExprNode *falseNode = node->falseExpr;

llvm::Value* resultValue;
if (trueNode->hasCompileTimeValue() && falseNode->hasCompileTimeValue()) {
// If both are constants, we can simply emit a selection instruction
llvm::Value *trueValue = resolveValue(trueNode);
llvm::Value *falseValue = resolveValue(falseNode);
resultValue = builder.CreateSelect(condValue, trueValue, falseValue);
} else {
trueValue = resolveValue(node->trueExpr);
falseValue = resolveValue(node->falseExpr);
// We have at least one non-constant value, use branching to not perform both sides
const std::string codeLoc = node->codeLoc.toPrettyLineAndColumn();
llvm::BasicBlock *condTrue = createBlock("cond.true." + codeLoc);
llvm::BasicBlock *condFalse = createBlock("cond.false." + codeLoc);
llvm::BasicBlock *condExit = createBlock("cond.exit." + codeLoc);

// Jump from original block to true or false block, depending on condition
insertCondJump(condValue, condTrue, condFalse);

// Fill true block
switchToBlock(condTrue);
llvm::Value *trueValue = resolveValue(trueNode);
insertJump(condExit);

// Fill false block
switchToBlock(condFalse);
llvm::Value *falseValue = resolveValue(falseNode);
insertJump(condExit);

// Fill the exit block
switchToBlock(condExit);
llvm::PHINode* phiInst = builder.CreatePHI(trueValue->getType(), 2, "cond.result");
phiInst->addIncoming(trueValue, condTrue);
phiInst->addIncoming(falseValue, condFalse);
resultValue = phiInst;
}

llvm::Value *resultValue = builder.CreateSelect(condValue, trueValue, falseValue);
return LLVMExprResult{.value = resultValue};
}

Expand Down
6 changes: 3 additions & 3 deletions src/irgenerator/IRGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ llvm::Value *IRGenerator::insertInBoundsGEP(llvm::Type *type, llvm::Value *baseP
std::string varName) const {
assert(basePtr->getType()->isPointerTy());
assert(!indices.empty());
assert(std::ranges::all_of(indices, [](llvm::Value *index) {
llvm::Type *indexType = index->getType();
assert(std::ranges::all_of(indices, [](const llvm::Value *index) {
const llvm::Type *indexType = index->getType();
return indexType->isIntegerTy(32) || indexType->isIntegerTy(64);
}));

Expand Down Expand Up @@ -465,7 +465,7 @@ LLVMExprResult IRGenerator::doAssignment(llvm::Value *lhsAddress, SymbolTableEnt

if (isDecl && rhsSType.is(TY_STRUCT) && rhs.isTemporary()) {
assert(lhsEntry != nullptr);
// Directly set the address to the lhs entry
// Directly set the address to the lhs entry (temp stealing)
llvm::Value *rhsAddress = resolveAddress(rhs);
lhsEntry->updateAddress(rhsAddress);
rhs.entry = lhsEntry;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,19 @@ define dso_local i32 @main() #1 {
%result = alloca i32, align 4
store i32 0, ptr %result, align 4
%1 = call i1 @_Z2f1v()
%2 = call i1 @_Z2f2v()
%3 = select i1 %1, i1 %1, i1 %2
%4 = zext i1 %3 to i32
br i1 %1, label %cond.true.L12C26, label %cond.false.L12C26

cond.true.L12C26: ; preds = %0
%2 = call i1 @_Z2f1v()
br label %cond.exit.L12C26

cond.false.L12C26: ; preds = %0
%3 = call i1 @_Z2f2v()
br label %cond.exit.L12C26

cond.exit.L12C26: ; preds = %cond.false.L12C26, %cond.true.L12C26
%cond.result = phi i1 [ %2, %cond.true.L12C26 ], [ %3, %cond.false.L12C26 ]
%4 = zext i1 %cond.result to i32
%5 = call i32 (ptr, ...) @printf(ptr noundef @printf.str.2, i32 %4)
%6 = load i32, ptr %result, align 4
ret i32 %6
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Result: 3
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
; ModuleID = 'source.spice'
source_filename = "source.spice"

@printf.str.0 = private unnamed_addr constant [11 x i8] c"Result: %d\00", align 1

define private i1 @_Z10condition1v() {
%result = alloca i1, align 1
ret i1 false
}

define private i1 @_Z10condition2v() {
%result = alloca i1, align 1
ret i1 true
}

; Function Attrs: noinline nounwind optnone uwtable
define dso_local i32 @main() #0 {
%result = alloca i32, align 4
store i32 0, ptr %result, align 4
%1 = call i1 @_Z10condition1v()
br i1 %1, label %land.1.L10C26, label %land.exit.L10C26

land.1.L10C26: ; preds = %0
%2 = call i1 @_Z10condition2v()
br label %land.exit.L10C26

land.exit.L10C26: ; preds = %land.1.L10C26, %0
%land_phi = phi i1 [ %1, %0 ], [ %2, %land.1.L10C26 ]
%3 = select i1 %land_phi, i32 2, i32 3
%4 = call i32 (ptr, ...) @printf(ptr noundef @printf.str.0, i32 %3)
%5 = load i32, ptr %result, align 4
ret i32 %5
}

; Function Attrs: nofree nounwind
declare noundef i32 @printf(ptr nocapture noundef readonly, ...) #1

attributes #0 = { noinline nounwind optnone uwtable }
attributes #1 = { nofree nounwind }
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
f<bool> condition1() {
return false;
}

f<bool> condition2() {
return true;
}

f<int> main() {
printf("Result: %d", condition1() && condition2() ? 2: 3);
}
Original file line number Diff line number Diff line change
@@ -1 +1 @@
Result: 3
Result: false
Original file line number Diff line number Diff line change
@@ -1,39 +1,64 @@
; ModuleID = 'source.spice'
source_filename = "source.spice"

@printf.str.0 = private unnamed_addr constant [11 x i8] c"Result: %d\00", align 1
@anon.string.0 = private unnamed_addr constant [56 x i8] c"Assertion failed: Condition 'false' evaluated to false.\00", align 1
@anon.string.1 = private unnamed_addr constant [6 x i8] c"false\00", align 1
@printf.str.0 = private unnamed_addr constant [11 x i8] c"Result: %s\00", align 1

define private i1 @_Z10condition1v() {
define private i1 @_Z7condFctv() {
%result = alloca i1, align 1
ret i1 false
}

define private i1 @_Z10condition2v() {
%result = alloca i1, align 1
ret i1 true
define private ptr @_Z7trueFctv() {
%result = alloca ptr, align 8
br i1 false, label %assert.exit.L6, label %assert.then.L6, !prof !0

assert.then.L6: ; preds = %0
%1 = call i32 (ptr, ...) @printf(ptr @anon.string.0)
call void @exit(i32 1)
unreachable

assert.exit.L6: ; preds = %0
%2 = load ptr, ptr %result, align 8
ret ptr %2
}

; Function Attrs: nofree nounwind
declare noundef i32 @printf(ptr nocapture noundef readonly, ...) #0

; Function Attrs: cold noreturn nounwind
declare void @exit(i32) #1

define private ptr @_Z8falseFctv() {
%result = alloca ptr, align 8
ret ptr @anon.string.1
}

; Function Attrs: noinline nounwind optnone uwtable
define dso_local i32 @main() #0 {
define dso_local i32 @main() #2 {
%result = alloca i32, align 4
store i32 0, ptr %result, align 4
%1 = call i1 @_Z10condition1v()
br i1 %1, label %land.1.L10C26, label %land.exit.L10C26
%1 = call i1 @_Z7condFctv()
br i1 %1, label %cond.true.L15C26, label %cond.false.L15C26

land.1.L10C26: ; preds = %0
%2 = call i1 @_Z10condition2v()
br label %land.exit.L10C26
cond.true.L15C26: ; preds = %0
%2 = call ptr @_Z7trueFctv()
br label %cond.exit.L15C26

land.exit.L10C26: ; preds = %land.1.L10C26, %0
%land_phi = phi i1 [ %1, %0 ], [ %2, %land.1.L10C26 ]
%3 = select i1 %land_phi, i32 2, i32 3
%4 = call i32 (ptr, ...) @printf(ptr noundef @printf.str.0, i32 %3)
cond.false.L15C26: ; preds = %0
%3 = call ptr @_Z8falseFctv()
br label %cond.exit.L15C26

cond.exit.L15C26: ; preds = %cond.false.L15C26, %cond.true.L15C26
%cond.result = phi ptr [ %2, %cond.true.L15C26 ], [ %3, %cond.false.L15C26 ]
%4 = call i32 (ptr, ...) @printf(ptr noundef @printf.str.0, ptr %cond.result)
%5 = load i32, ptr %result, align 4
ret i32 %5
}

; Function Attrs: nofree nounwind
declare noundef i32 @printf(ptr nocapture noundef readonly, ...) #1
attributes #0 = { nofree nounwind }
attributes #1 = { cold noreturn nounwind }
attributes #2 = { noinline nounwind optnone uwtable }

attributes #0 = { noinline nounwind optnone uwtable }
attributes #1 = { nofree nounwind }
!0 = !{!"branch_weights", i32 2000, i32 1}
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
f<bool> condition1() {
f<bool> condFct() {
return false;
}

f<bool> condition2() {
return true;
f<string> trueFct() {
assert false; // Should not be called
return "true";
}

f<string> falseFct() {
return "false";
}

f<int> main() {
printf("Result: %d", condition1() && condition2() ? 2: 3);
printf("Result: %s", condFct() ? trueFct() : falseFct());
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,22 @@ define dso_local i32 @main() #0 {
store i32 0, ptr %result, align 4
store i1 true, ptr %condition, align 1
%1 = load i1, ptr %condition, align 1
br i1 %1, label %cond.true.L7C13, label %cond.false.L7C13

cond.true.L7C13: ; preds = %0
%2 = call i32 @_Z3getv()
%3 = select i1 %1, i32 %2, i32 24
store i32 %3, ptr %r, align 4
%4 = load i32, ptr %r, align 4
%5 = call i32 (ptr, ...) @printf(ptr noundef @printf.str.0, i32 %4)
%6 = load i32, ptr %result, align 4
ret i32 %6
br label %cond.exit.L7C13

cond.false.L7C13: ; preds = %0
br label %cond.exit.L7C13

cond.exit.L7C13: ; preds = %cond.false.L7C13, %cond.true.L7C13
%cond.result = phi i32 [ %2, %cond.true.L7C13 ], [ 24, %cond.false.L7C13 ]
store i32 %cond.result, ptr %r, align 4
%3 = load i32, ptr %r, align 4
%4 = call i32 (ptr, ...) @printf(ptr noundef @printf.str.0, i32 %3)
%5 = load i32, ptr %result, align 4
ret i32 %5
}

; Function Attrs: nofree nounwind
Expand Down
Loading