diff --git a/src/RankLib.Cli/EvaluateCommand.cs b/src/RankLib.Cli/EvaluateCommand.cs
index 2d9e5e8..2c13125 100644
--- a/src/RankLib.Cli/EvaluateCommand.cs
+++ b/src/RankLib.Cli/EvaluateCommand.cs
@@ -75,7 +75,7 @@ public EvaluateCommand()
 		AddOption(new Option<RankerType>(["-ranker", "--ranker"], () => RankerType.CoordinateAscent, "Ranking algorithm to use"));
 		AddOption(new Option<FileInfo>(["-feature", "--feature-description-input-file"], "Feature description file. list features to be considered by the learner, each on a separate line. If not specified, all features will be used.").ExistingOnly());
 		AddOption(new Option<string>(["-metric2t", "--train-metric"], () => "ERR@10", "Metric to optimize on the training data. Supports MAP, NDCG@k, DCG@k, P@k, RR@k, ERR@k."));
-		AddOption(new Option<double?>(["-gmax", "--max-label"], "Highest judged relevance label. It affects the calculation of ERR (default=4, i.e. 5-point scale [0,1,2,3,4] where value used is 2^gmax)"));
+		AddOption(new Option<double?>(["-gmax", "--max-label"], () => ERRScorer.DefaultMax, "Highest judged relevance label. It affects the calculation of ERR, i.e. a 5-point scale [0,1,2,3,4] where the value used is 2^gmax"));
 		AddOption(new Option<FileInfo>(["-qrel", "--query-relevance-input-file"], "TREC-style relevance judgment file").ExistingOnly());
 		AddOption(new Option<bool>(["-missingZero", "--missing-zero"], "Substitute zero for missing feature values rather than throwing an exception."));
 		AddOption(new Option<FileInfo>(["-validate", "--validate-file"], "Specify if you want to tune your system on the validation data").ExistingOnly());
@@ -428,7 +428,7 @@ public async Task<int> HandleAsync(EvaluateCommandOptions options, CancellationToken cancellationToken)
 
 		if (trainMetric.StartsWith("ERR", StringComparison.OrdinalIgnoreCase)
 			|| (testMetric != null && testMetric.StartsWith("ERR", StringComparison.OrdinalIgnoreCase)))
-			logger.LogInformation("Highest relevance label (to compute ERR): {HighRelevanceLabel}", (int)SimpleMath.LogBase2(ERRScorer.DefaultMax));
+			logger.LogInformation("Highest relevance label (to compute ERR): {HighRelevanceLabel}", options.MaxLabel ?? ERRScorer.DefaultMax);
 
 		if (options.QueryRelevanceInputFile != null)
 			logger.LogInformation("TREC-format relevance judgment (only affects MAP and NDCG scores): {QueryRelevanceJudgementFile}", options.QueryRelevanceInputFile.FullName);
@@ -554,7 +554,7 @@ await evaluator.EvaluateAsync(
 		{
 			logger.LogInformation("Test metric: {TestMetric}", testMetric);
 			if (testMetric.StartsWith("ERR", StringComparison.OrdinalIgnoreCase))
-				logger.LogInformation("Highest relevance label (to compute ERR): {HighestRelevanceLabel}", options.MaxLabel ?? (int)SimpleMath.LogBase2(ERRScorer.DefaultMax));
+				logger.LogInformation("Highest relevance label (to compute ERR): {HighestRelevanceLabel}", options.MaxLabel ?? ERRScorer.DefaultMax);
 
 			if (savedModelFiles.Count > 1)
 			{
diff --git a/src/RankLib/Eval/Analyzer.cs b/src/RankLib/Eval/Analyzer.cs
index c81b4d9..704621c 100644
--- a/src/RankLib/Eval/Analyzer.cs
+++ b/src/RankLib/Eval/Analyzer.cs
@@ -20,7 +20,7 @@ public Analyzer(ISignificanceTest test, ILogger? logger = null)
 		_logger = logger ?? NullLogger.Instance;
 	}
 
-	public class Result
+	private class Result
 	{
 		public int Status = 0;
 		public int Win = 0;
@@ -85,8 +85,7 @@ public void Compare(List<string> targetFiles, string baseFile)
 
 		_logger.LogInformation("Overall comparison");
 		_logger.LogInformation("System\tPerformance\tImprovement\tWin\tLoss\tp-value");
-
-		_logger.LogInformation($"{Path.GetFileName(baseFile)} [baseline]\t{basePerformance["all"]:F4}");
+		_logger.LogInformation("{FileName} [baseline]\t{Performance}", Path.GetFileName(baseFile), basePerformance["all"].ToString("F4"));
 
 		for (var i = 0; i < results.Length; i++)
 		{
@@ -131,7 +130,7 @@ public void Compare(List<string> targetFiles, string baseFile)
 		}
 	}
 
-	public Result[] Compare(Dictionary<string, double> basePerformance, List<Dictionary<string, double>> targets)
+	private Result[] Compare(Dictionary<string, double> basePerformance, List<Dictionary<string, double>> targets)
 	{
 		var results = new Result[targets.Count];
 		for (var i = 0; i < targets.Count; i++)
@@ -140,7 +139,7 @@ public Result[] Compare(Dictionary<string, double> basePerformance, List<Dictionary<string, double>> targets)
 		return results;
 	}
 
-	public Result Compare(Dictionary<string, double> basePerformance, Dictionary<string, double> target)
+	private Result Compare(Dictionary<string, double> basePerformance, Dictionary<string, double> target)
 	{
 		var result = new Result
 		{
diff --git a/src/RankLib/Learning/Combiner.cs b/src/RankLib/Learning/Combiner.cs
index 0c447fc..7952357 100644
--- a/src/RankLib/Learning/Combiner.cs
+++ b/src/RankLib/Learning/Combiner.cs
@@ -1,4 +1,6 @@
 using System.Text;
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Logging.Abstractions;
 using RankLib.Learning.Tree;
 using RankLib.Utilities;
@@ -10,12 +12,18 @@ namespace RankLib.Learning;
 public class Combiner
 {
 	private readonly RankerFactory _rankerFactory;
+	private readonly ILogger _logger;
 
 	/// <summary>
 	/// Instantiates a new instance of <see cref="Combiner"/>
 	/// </summary>
 	/// <param name="rankerFactory">The ranker factory to use to read ranker models</param>
-	public Combiner(RankerFactory rankerFactory) => _rankerFactory = rankerFactory;
+	/// <param name="logger">Logger to log messages</param>
+	public Combiner(RankerFactory rankerFactory, ILogger? logger = null)
+	{
+		_rankerFactory = rankerFactory;
+		_logger = logger ?? NullLogger.Instance;
+	}
 
 	/// <summary>
 	/// Combines the first <see cref="Ensemble"/> from each <see cref="RandomForests"/>
 	/// </summary>
@@ -43,6 +51,8 @@ public void Combine(string directory, string outputFile)
 					var ensemble = randomForests.Ensembles[0];
 					writer.Write(ensemble.ToString());
 				}
+				else
+					_logger.LogError("{File} is not a RandomForests ranker. Skipping", file);
 			}
 		}
 		catch (Exception ex)
diff --git a/src/RankLib/Learning/DataPoint.cs b/src/RankLib/Learning/DataPoint.cs
index b1cc289..bdca858 100644
--- a/src/RankLib/Learning/DataPoint.cs
+++ b/src/RankLib/Learning/DataPoint.cs
@@ -184,7 +184,7 @@ protected DataPoint(ReadOnlySpan<char> span)
 	/// <summary>
 	/// Gets the feature count
 	/// </summary>
-	public int FeatureCount { get; private set; }
+	public int FeatureCount { get; protected set; }
 
 	public void ResetCached() => Cached = -1;
 
diff --git a/src/RankLib/Learning/DenseDataPoint.cs b/src/RankLib/Learning/DenseDataPoint.cs
index c2f5a99..6e04574 100644
--- a/src/RankLib/Learning/DenseDataPoint.cs
+++ b/src/RankLib/Learning/DenseDataPoint.cs
@@ -25,6 +25,10 @@ public DenseDataPoint(ReadOnlySpan<char> span) : base(span) { }
 	public DenseDataPoint(float label, string id, float[] featureValues, string? description = null)
 		: base(label, id, featureValues, description) { }
 
+	/// <summary>
+	/// Initializes a new instance of <see cref="DenseDataPoint"/> from another dense data point.
+	/// </summary>
+	/// <param name="dataPoint">The data point to copy.</param>
 	public DenseDataPoint(DenseDataPoint dataPoint)
 	{
 		Label = dataPoint.Label;
@@ -32,6 +36,7 @@ public DenseDataPoint(DenseDataPoint dataPoint)
 		Description = dataPoint.Description;
 		Cached = dataPoint.Cached;
 		FeatureValues = new float[dataPoint.FeatureValues.Length];
+		FeatureCount = dataPoint.FeatureCount;
 		Array.Copy(dataPoint.FeatureValues, FeatureValues, dataPoint.FeatureValues.Length);
 	}
 
@@ -45,15 +50,15 @@ public override float GetFeatureValue(int featureId)
 			throw RankLibException.Create($"Error in DenseDataPoint::GetFeatureValue(): requesting unspecified feature, fid={featureId}");
 		}
 
-		return IsUnknown(FeatureValues[featureId]) ? 0 : FeatureValues[featureId];
+		var featureValue = FeatureValues[featureId];
+		return IsUnknown(featureValue) ? 0 : featureValue;
 	}
 
 	public override void SetFeatureValue(int featureId, float featureValue)
 	{
 		if (featureId <= 0 || featureId >= FeatureValues.Length)
-		{
 			throw RankLibException.Create($"Error in DenseDataPoint::SetFeatureValue(): feature (id={featureId}) not found.");
-		}
+
 		FeatureValues[featureId] = featureValue;
 	}
diff --git a/src/RankLib/Learning/IRanker.cs b/src/RankLib/Learning/IRanker.cs
index 6aa8a29..93b4958 100644
--- a/src/RankLib/Learning/IRanker.cs
+++ b/src/RankLib/Learning/IRanker.cs
@@ -51,7 +51,7 @@ public interface IRanker
 	Task LearnAsync();
 
 	/// <summary>
-	/// Evaluates a datapoint.
+	/// Evaluates a data point.
 	/// </summary>
 	/// <param name="dataPoint">The data point.</param>
 	/// <returns>The score for the data point</returns>
diff --git a/src/RankLib/Learning/RankList.cs b/src/RankLib/Learning/RankList.cs
index 3f57b8e..1b8ee14 100644
--- a/src/RankLib/Learning/RankList.cs
+++ b/src/RankLib/Learning/RankList.cs
@@ -6,7 +6,7 @@
 namespace RankLib.Learning;
 
 /// <summary>
-/// A list of <see cref="DataPoint"/> to be ranked.
+/// A list of <see cref="DataPoint"/> for ranking.
 /// </summary>
 [DebuggerDisplay("Count={Count}, FeatureCount={FeatureCount}")]
 public class RankList : IEnumerable<DataPoint>
diff --git a/src/RankLib/Learning/SparseDataPoint.cs b/src/RankLib/Learning/SparseDataPoint.cs
index bb5aadd..70fb931 100644
--- a/src/RankLib/Learning/SparseDataPoint.cs
+++ b/src/RankLib/Learning/SparseDataPoint.cs
@@ -2,6 +2,9 @@
 
 namespace RankLib.Learning;
 
+/// <summary>
+/// A sparse data point
+/// </summary>
 public sealed class SparseDataPoint : DataPoint
 {
 	// Access pattern of the feature values
@@ -19,6 +22,10 @@ private enum AccessPattern
 	private int _lastMinId = -1;
 	private int _lastMinPos = -1;
 
+	/// <summary>
+	/// Initializes a new instance of <see cref="SparseDataPoint"/> from the given span.
+	/// </summary>
+	/// <param name="span">The span containing data to initialize the instance with</param>
 	public SparseDataPoint(ReadOnlySpan<char> span) : base(span) { }
 
 	/// <summary>
@@ -33,16 +40,21 @@ public SparseDataPoint(ReadOnlySpan<char> span) : base(span) { }
 	public SparseDataPoint(float label, string id, float[] featureValues, string? description = null)
 		: base(label, id, featureValues, description) { }
 
-	public SparseDataPoint(SparseDataPoint sparseDataPoint)
+	/// <summary>
+	/// Initializes a new instance of <see cref="SparseDataPoint"/> from another sparse data point.
+	/// </summary>
+	/// <param name="dataPoint">The data point to copy.</param>
+	public SparseDataPoint(SparseDataPoint dataPoint)
 	{
-		Label = sparseDataPoint.Label;
-		Id = sparseDataPoint.Id;
-		Description = sparseDataPoint.Description;
-		Cached = sparseDataPoint.Cached;
-		_featureIds = new int[sparseDataPoint._featureIds.Length];
-		FeatureValues = new float[sparseDataPoint.FeatureValues.Length];
-		Array.Copy(sparseDataPoint._featureIds, 0, _featureIds, 0, sparseDataPoint._featureIds.Length);
-		Array.Copy(sparseDataPoint.FeatureValues, 0, FeatureValues, 0, sparseDataPoint.FeatureValues.Length);
+		Label = dataPoint.Label;
+		Id = dataPoint.Id;
+		Description = dataPoint.Description;
+		Cached = dataPoint.Cached;
+		_featureIds = new int[dataPoint._featureIds.Length];
+		FeatureValues = new float[dataPoint.FeatureValues.Length];
+		FeatureCount = dataPoint.FeatureCount;
+		Array.Copy(dataPoint._featureIds, 0, _featureIds, 0, dataPoint._featureIds.Length);
+		Array.Copy(dataPoint.FeatureValues, 0, FeatureValues, 0, dataPoint.FeatureValues.Length);
 	}
 
 	private int Locate(int fid)
diff --git a/src/RankLib/Learning/Tree/RandomForests.cs b/src/RankLib/Learning/Tree/RandomForests.cs
index c064ae8..a45e328 100644
--- a/src/RankLib/Learning/Tree/RandomForests.cs
+++ b/src/RankLib/Learning/Tree/RandomForests.cs
@@ -34,6 +34,16 @@ public class RandomForestsParameters : IRankerParameters
 	/// </summary>
 	public const float DefaultSubSamplingRate = 1.0f;
 
+	/// <summary>
+	/// Default number of trees in each bag
+	/// </summary>
+	public const int DefaultTreeCount = 1;
+
+	/// <summary>
+	/// Default number of tree leaves
+	/// </summary>
+	public const int DefaultTreeLeavesCount = 100;
+
 	private RankerType _rankerType = DefaultRankerType;
 
 	// Parameters
@@ -75,27 +85,27 @@ public RankerType RankerType
 	/// <summary>
 	/// Number of trees in each bag
 	/// </summary>
-	public int TreeCount { get; set; } = 1;
+	public int TreeCount { get; set; } = DefaultTreeCount;
 
 	/// <summary>
 	/// Number of leaves in each tree
 	/// </summary>
-	public int TreeLeavesCount { get; set; } = 100;
+	public int TreeLeavesCount { get; set; } = DefaultTreeLeavesCount;
 
 	/// <summary>
-	/// The learning rate, or shrinkage, only matters if <see cref="TreeCount"/> > 1
+	/// The learning rate, or shrinkage, only matters if <see cref="TreeCount"/> is greater than 1
 	/// </summary>
-	public float LearningRate { get; set; } = 0.1F;
+	public float LearningRate { get; set; } = LambdaMARTParameters.DefaultLearningRate;
 
 	/// <summary>
 	/// The number of threshold candidates.
 	/// </summary>
-	public int Threshold { get; set; } = 256;
+	public int Threshold { get; set; } = LambdaMARTParameters.DefaultThreshold;
 
 	/// <summary>
 	/// Minimum leaf support
 	/// </summary>
-	public int MinimumLeafSupport { get; set; } = 1;
+	public int MinimumLeafSupport { get; set; } = LambdaMARTParameters.DefaultMinimumLeafSupport;
 
 	/// <summary>
 	/// Gets or sets the maximum number of concurrent tasks allowed when splitting up workloads
diff --git a/src/RankLib/Learning/Tree/RegressionTree.cs b/src/RankLib/Learning/Tree/RegressionTree.cs
index 2e7f889..cce17dd 100644
--- a/src/RankLib/Learning/Tree/RegressionTree.cs
+++ b/src/RankLib/Learning/Tree/RegressionTree.cs
@@ -24,9 +24,9 @@ public RegressionTree(Split root)
 		_leaves = root.Leaves();
 	}
 
-	public RegressionTree(int nLeaves, DataPoint[] trainingSamples, double[] labels, FeatureHistogram hist, int minLeafSupport)
+	public RegressionTree(int treeLeavesCount, DataPoint[] trainingSamples, double[] labels, FeatureHistogram hist, int minLeafSupport)
 	{
-		_nodes = nLeaves;
+		_nodes = treeLeavesCount;
 		_trainingLabels = labels;
 		_hist = hist;
 		_minLeafSupport = minLeafSupport;
diff --git a/src/RankLib/Learning/Tree/Split.cs b/src/RankLib/Learning/Tree/Split.cs
index ab2777c..58f34eb 100644
--- a/src/RankLib/Learning/Tree/Split.cs
+++ b/src/RankLib/Learning/Tree/Split.cs
@@ -59,12 +59,15 @@ public void Set(int featureId, float threshold, double deviance)
 		_deviance = deviance;
 	}
 
+	/// <summary>
+	/// Gets the feature histogram
+	/// </summary>
 	public FeatureHistogram? Histogram { get; private set; }
 
 	/// <summary>
 	/// Whether this split is a root split.
 	/// </summary>
-	public bool IsRoot { get; set; }
+	public bool IsRoot { get; init; }
 
 	/// <summary>
 	/// Gets or sets the left split.
@@ -160,7 +163,7 @@ private string GetString(string indent)
 	// Internal functions (ONLY used during learning)
 	//*DO NOT* attempt to call them once the training is done
 
-	public async Task TrySplitAsync(double[] trainingLabels, int minLeafSupport)
+	internal async Task TrySplitAsync(double[] trainingLabels, int minLeafSupport)
 	{
 		if (Histogram is null)
 			throw new InvalidOperationException("Histogram is null");
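
Usage note: with the optional logger now accepted by Combiner, callers can surface the new "is not a RandomForests ranker" message instead of files being skipped silently. The snippet below is a minimal sketch of wiring that up; the console logging provider and the parameterless RankerFactory construction are assumptions for illustration and are not part of this diff.

using Microsoft.Extensions.Logging;
using RankLib.Learning;

// Create a logger factory; AddConsole() assumes the Microsoft.Extensions.Logging.Console package is referenced.
using var loggerFactory = LoggerFactory.Create(builder => builder.AddConsole());

// Assumption: RankerFactory is shown with a default construction here; adjust to the actual API.
var combiner = new Combiner(new RankerFactory(), loggerFactory.CreateLogger<Combiner>());

// Combine the first ensemble from each RandomForests model in "models" into a single output file,
// logging any file that does not contain a RandomForests ranker.
combiner.Combine("models", "combined-model.txt");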