Skip to content

Commit

Permalink
Merge pull request #83 from siboehm/siboehm/safeSoftmax
Browse files Browse the repository at this point in the history
safe softmax
  • Loading branch information
siboehm authored Aug 3, 2024
2 parents a82ec55 + e50798e commit 7b38bac
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions lleaves/compiler/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,14 +367,20 @@ def _populate_sigmoid(alpha):
result = args[0]
elif objective == "multiclass":
assert len(args)
# TODO Might profit from vectorization, needs testing
result = [builder.call(llvm_exp, [arg]) for arg in args]

# stable softmax
max_val = args[0]
for arg in args[1:]:
max_val = builder.select(
builder.fcmp_ordered(">", arg, max_val), arg, max_val
)
exp_vals = [
builder.call(llvm_exp, [builder.fsub(arg, max_val)]) for arg in args
]
denominator = get_fdtype_const(0.0, use_fp64)
for r in result:
denominator = builder.fadd(r, denominator)

result = [builder.fdiv(r, denominator) for r in result]
for exp_val in exp_vals:
denominator = builder.fadd(exp_val, denominator)
denominator = builder.fadd(denominator, get_fdtype_const(1e-15, use_fp64))
result = [builder.fdiv(exp_val, denominator) for exp_val in exp_vals]
else:
raise ValueError(
f"Objective '{objective}' not yet implemented. {ISSUE_ERROR_MSG}"
Expand Down

0 comments on commit 7b38bac

Please sign in to comment.