From e50798e9280b861df199399ddb9d9d2ba3afc7a5 Mon Sep 17 00:00:00 2001 From: Simon Boehm Date: Sat, 3 Aug 2024 10:46:36 -0700 Subject: [PATCH] safe softmax --- lleaves/compiler/codegen/codegen.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/lleaves/compiler/codegen/codegen.py b/lleaves/compiler/codegen/codegen.py index bf2e0a3..5b2b5b4 100644 --- a/lleaves/compiler/codegen/codegen.py +++ b/lleaves/compiler/codegen/codegen.py @@ -367,14 +367,20 @@ def _populate_sigmoid(alpha): result = args[0] elif objective == "multiclass": assert len(args) - # TODO Might profit from vectorization, needs testing - result = [builder.call(llvm_exp, [arg]) for arg in args] - + # stable softmax + max_val = args[0] + for arg in args[1:]: + max_val = builder.select( + builder.fcmp_ordered(">", arg, max_val), arg, max_val + ) + exp_vals = [ + builder.call(llvm_exp, [builder.fsub(arg, max_val)]) for arg in args + ] denominator = get_fdtype_const(0.0, use_fp64) - for r in result: - denominator = builder.fadd(r, denominator) - - result = [builder.fdiv(r, denominator) for r in result] + for exp_val in exp_vals: + denominator = builder.fadd(exp_val, denominator) + denominator = builder.fadd(denominator, get_fdtype_const(1e-15, use_fp64)) + result = [builder.fdiv(exp_val, denominator) for exp_val in exp_vals] else: raise ValueError( f"Objective '{objective}' not yet implemented. {ISSUE_ERROR_MSG}"