Commit

More tidying
russcam committed Nov 14, 2024
1 parent 4645dfd commit c127393
Showing 11 changed files with 73 additions and 34 deletions.
6 changes: 3 additions & 3 deletions src/RankLib.Cli/EvaluateCommand.cs
@@ -75,7 +75,7 @@ public EvaluateCommand()
AddOption(new Option<RankerType>(["-ranker", "--ranker"], () => RankerType.CoordinateAscent, "Ranking algorithm to use"));
AddOption(new Option<FileInfo>(["-feature", "--feature-description-input-file"], "Feature description file. list features to be considered by the learner, each on a separate line. If not specified, all features will be used.").ExistingOnly());
AddOption(new Option<string>(["-metric2t", "--train-metric"], () => "ERR@10", "Metric to optimize on the training data. Supports MAP, NDCG@k, DCG@k, P@k, RR@k, ERR@k."));
AddOption(new Option<double?>(["-gmax", "--max-label"], "Highest judged relevance label. It affects the calculation of ERR (default=4, i.e. 5-point scale [0,1,2,3,4] where value used is 2^gmax)"));
AddOption(new Option<double?>(["-gmax", "--max-label"], () => ERRScorer.DefaultMax, "Highest judged relevance label. It affects the calculation of ERR i.e. 5-point scale [0,1,2,3,4] where value used is 2^gmax"));
AddOption(new Option<FileInfo>(["-qrel", "--query-relevance-input-file"], "TREC-style relevance judgment file").ExistingOnly());
AddOption(new Option<bool>(["-missingZero", "--missing-zero"], "Substitute zero for missing feature values rather than throwing an exception."));
AddOption(new Option<FileInfo>(["-validate", "--validate-file"], "Specify if you want to tune your system on the validation data").ExistingOnly());
@@ -428,7 +428,7 @@ public async Task<int> HandleAsync(EvaluateCommandOptions options, CancellationT

if (trainMetric.StartsWith("ERR", StringComparison.OrdinalIgnoreCase)
|| (testMetric != null && testMetric.StartsWith("ERR", StringComparison.OrdinalIgnoreCase)))
logger.LogInformation("Highest relevance label (to compute ERR): {HighRelevanceLabel}", (int)SimpleMath.LogBase2(ERRScorer.DefaultMax));
logger.LogInformation("Highest relevance label (to compute ERR): {HighRelevanceLabel}", options.MaxLabel ?? ERRScorer.DefaultMax);

if (options.QueryRelevanceInputFile != null)
logger.LogInformation("TREC-format relevance judgment (only affects MAP and NDCG scores): {QueryRelevanceJudgementFile}", options.QueryRelevanceInputFile.FullName);
@@ -554,7 +554,7 @@ await evaluator.EvaluateAsync(
{
logger.LogInformation("Test metric: {TestMetric}", testMetric);
if (testMetric.StartsWith("ERR", StringComparison.OrdinalIgnoreCase))
logger.LogInformation("Highest relevance label (to compute ERR): {HighestRelevanceLabel}", options.MaxLabel ?? (int)SimpleMath.LogBase2(ERRScorer.DefaultMax));
logger.LogInformation("Highest relevance label (to compute ERR): {HighestRelevanceLabel}", options.MaxLabel ?? ERRScorer.DefaultMax);

if (savedModelFiles.Count > 1)
{
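For context on the -gmax change above: the option's help text describes a 5-point scale [0,1,2,3,4] in which the value used by ERR is 2^gmax. A minimal sketch of that relationship, assuming the standard ERR gain formulation (2^label - 1) / 2^gmax; the ErrGain helper below is illustrative and not taken from ERRScorer.

using System;

// Standard ERR gain for a judged label on a scale whose highest label is gmax
// (assumed formulation; RankLib's ERRScorer may differ in details).
static double ErrGain(int label, double gmax) => (Math.Pow(2, label) - 1) / Math.Pow(2, gmax);

Console.WriteLine(ErrGain(4, 4)); // highest label on the default 5-point scale: 15/16 = 0.9375
Console.WriteLine(ErrGain(2, 4)); // mid-scale label: 3/16 = 0.1875
Console.WriteLine(ErrGain(0, 4)); // non-relevant document: 0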
9 changes: 4 additions & 5 deletions src/RankLib/Eval/Analyzer.cs
@@ -20,7 +20,7 @@ public Analyzer(ISignificanceTest test, ILogger<Analyzer>? logger = null)
_logger = logger ?? NullLogger<Analyzer>.Instance;
}

public class Result
private class Result
{
public int Status = 0;
public int Win = 0;
@@ -85,8 +85,7 @@ public void Compare(List<string> targetFiles, string baseFile)

_logger.LogInformation("Overall comparison");
_logger.LogInformation("System\tPerformance\tImprovement\tWin\tLoss\tp-value");

_logger.LogInformation($"{Path.GetFileName(baseFile)} [baseline]\t{basePerformance["all"]:F4}");
_logger.LogInformation("{FileName} [baseline]\t{Performance}", Path.GetFileName(baseFile), basePerformance["all"].ToString("F4"));

for (var i = 0; i < results.Length; i++)
{
@@ -131,7 +130,7 @@ public void Compare(List<string> targetFiles, string baseFile)
}
}

public Result[] Compare(Dictionary<string, double> basePerformance, List<Dictionary<string, double>> targets)
private Result[] Compare(Dictionary<string, double> basePerformance, List<Dictionary<string, double>> targets)
{
var results = new Result[targets.Count];
for (var i = 0; i < targets.Count; i++)
@@ -140,7 +139,7 @@ public Result[] Compare(Dictionary<string, double> basePerformance, List<Diction
return results;
}

public Result Compare(Dictionary<string, double> basePerformance, Dictionary<string, double> target)
private Result Compare(Dictionary<string, double> basePerformance, Dictionary<string, double> target)
{
var result = new Result
{
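The baseline log line above switches from string interpolation to a message template. A minimal sketch of the difference, assuming a console logger from Microsoft.Extensions.Logging; the file name and score are placeholder values.

using Microsoft.Extensions.Logging;

using var factory = LoggerFactory.Create(builder => builder.AddConsole());
var logger = factory.CreateLogger("Analyzer");

var fileName = "baseline.txt";
var performance = 0.4321;

// Interpolated: the logger only ever sees a pre-formatted string.
logger.LogInformation($"{fileName} [baseline]\t{performance:F4}");

// Message template: FileName and Performance are captured as named properties,
// which structured sinks can filter and index.
logger.LogInformation("{FileName} [baseline]\t{Performance}", fileName, performance.ToString("F4"));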
12 changes: 11 additions & 1 deletion src/RankLib/Learning/Combiner.cs
@@ -1,4 +1,6 @@
using System.Text;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using RankLib.Learning.Tree;
using RankLib.Utilities;

@@ -10,12 +12,18 @@ namespace RankLib.Learning;
public class Combiner
{
private readonly RankerFactory _rankerFactory;
private readonly ILogger<Combiner> _logger;

/// <summary>
/// Instantiates a new instance of <see cref="Combiner"/>
/// </summary>
/// <param name="rankerFactory">The ranker factory to use to read ranker models</param>
public Combiner(RankerFactory rankerFactory) => _rankerFactory = rankerFactory;
/// <param name="logger">Logger to log messages</param>
public Combiner(RankerFactory rankerFactory, ILogger<Combiner>? logger = null)
{
_rankerFactory = rankerFactory;
_logger = logger ?? NullLogger<Combiner>.Instance;
}

/// <summary>
/// Combines the first <see cref="Ensemble"/> from each <see cref="RandomForests"/>
@@ -43,6 +51,8 @@ public void Combine(string directory, string outputFile)
var ensemble = randomForests.Ensembles[0];
writer.Write(ensemble.ToString());
}
else
_logger.LogError("{File} is not a RandomForests ranker, skipping", file);
}
}
catch (Exception ex)
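A hedged usage sketch for the updated Combiner: it scans a directory of saved models and writes the first ensemble from each RandomForests ranker to a single output file. Only the Combiner signatures shown in this diff are taken from the change; the parameterless RankerFactory constructor and the paths are assumptions.

using Microsoft.Extensions.Logging;
using RankLib.Learning;

using var loggerFactory = LoggerFactory.Create(builder => builder.AddConsole());

// Assumption: RankerFactory can be constructed directly; the diff only shows it being injected.
var rankerFactory = new RankerFactory();

// The logger is optional; passing one surfaces the new "not a RandomForests ranker" message.
var combiner = new Combiner(rankerFactory, loggerFactory.CreateLogger<Combiner>());

// Combines the first ensemble from each RandomForests model found in the directory.
combiner.Combine("models/", "combined-model.txt");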
2 changes: 1 addition & 1 deletion src/RankLib/Learning/DataPoint.cs
@@ -184,7 +184,7 @@ protected DataPoint(ReadOnlySpan<char> span)
/// <summary>
/// Gets the feature count
/// </summary>
public int FeatureCount { get; private set; }
public int FeatureCount { get; protected set; }

public void ResetCached() => Cached = -1;

11 changes: 8 additions & 3 deletions src/RankLib/Learning/DenseDataPoint.cs
@@ -25,13 +25,18 @@ public DenseDataPoint(ReadOnlySpan<char> span) : base(span) { }
public DenseDataPoint(float label, string id, float[] featureValues, string? description = null)
: base(label, id, featureValues, description) { }

/// <summary>
/// Initializes a new instance of <see cref="DenseDataPoint"/> from another dense data point.
/// </summary>
/// <param name="dataPoint">The data point to copy.</param>
public DenseDataPoint(DenseDataPoint dataPoint)
{
Label = dataPoint.Label;
Id = dataPoint.Id;
Description = dataPoint.Description;
Cached = dataPoint.Cached;
FeatureValues = new float[dataPoint.FeatureValues.Length];
FeatureCount = dataPoint.FeatureCount;
Array.Copy(dataPoint.FeatureValues, FeatureValues, dataPoint.FeatureValues.Length);
}

@@ -45,15 +50,15 @@ public override float GetFeatureValue(int featureId)
throw RankLibException.Create($"Error in DenseDataPoint::GetFeatureValue(): requesting unspecified feature, fid={featureId}");
}

return IsUnknown(FeatureValues[featureId]) ? 0 : FeatureValues[featureId];
var featureValue = FeatureValues[featureId];
return IsUnknown(featureValue) ? 0 : featureValue;
}

public override void SetFeatureValue(int featureId, float featureValue)
{
if (featureId <= 0 || featureId >= FeatureValues.Length)
{
throw RankLibException.Create($"Error in DenseDataPoint::SetFeatureValue(): feature (id={featureId}) not found.");
}

FeatureValues[featureId] = featureValue;
}

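The copy constructor above now also carries FeatureCount across. A minimal sketch of the copy semantics, assuming feature ids are 1-based with index 0 unused, which is what the bounds checks in Get/SetFeatureValue suggest; the sample values are placeholders.

using System;
using RankLib.Learning;

// Index 0 is a dummy slot under the 1-based assumption; features 1..3 carry the values.
var original = new DenseDataPoint(label: 1f, id: "q1", featureValues: new float[] { 0f, 0.5f, 1.2f, 3f });

// The copy gets its own FeatureValues array (and, after this change, the same FeatureCount).
var copy = new DenseDataPoint(original);
copy.SetFeatureValue(2, 9.9f);

Console.WriteLine(original.GetFeatureValue(2)); // 1.2 — unchanged by mutating the copy
Console.WriteLine(copy.GetFeatureValue(2));     // 9.9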
2 changes: 1 addition & 1 deletion src/RankLib/Learning/IRanker.cs
@@ -51,7 +51,7 @@ public interface IRanker
Task LearnAsync();

/// <summary>
/// Evaluates a datapoint.
/// Evaluates a data point.
/// </summary>
/// <param name="dataPoint">The data point.</param>
/// <returns>The score for the data point</returns>
2 changes: 1 addition & 1 deletion src/RankLib/Learning/RankList.cs
@@ -6,7 +6,7 @@
namespace RankLib.Learning;

/// <summary>
/// A list of <see cref="DataPoint"/> to be ranked.
/// A list of <see cref="DataPoint"/> for ranking.
/// </summary>
[DebuggerDisplay("Count={Count}, FeatureCount={FeatureCount}")]
public class RankList : IEnumerable<DataPoint>
30 changes: 21 additions & 9 deletions src/RankLib/Learning/SparseDataPoint.cs
@@ -2,6 +2,9 @@

namespace RankLib.Learning;

/// <summary>
/// A sparse data point
/// </summary>
public sealed class SparseDataPoint : DataPoint
{
// Access pattern of the feature values
@@ -19,6 +22,10 @@ private enum AccessPattern
private int _lastMinId = -1;
private int _lastMinPos = -1;

/// <summary>
/// Initializes a new instance of <see cref="SparseDataPoint"/> from the given span.
/// </summary>
/// <param name="span">The span containing data to initialize the instance with</param>
public SparseDataPoint(ReadOnlySpan<char> span) : base(span) { }

/// <summary>
@@ -33,16 +40,21 @@ public SparseDataPoint(ReadOnlySpan<char> span) : base(span) { }
public SparseDataPoint(float label, string id, float[] featureValues, string? description = null)
: base(label, id, featureValues, description) { }

public SparseDataPoint(SparseDataPoint sparseDataPoint)
/// <summary>
/// Initializes a new instance of <see cref="SparseDataPoint"/> from another sparse data point.
/// </summary>
/// <param name="dataPoint">The data point to copy.</param>
public SparseDataPoint(SparseDataPoint dataPoint)
{
Label = sparseDataPoint.Label;
Id = sparseDataPoint.Id;
Description = sparseDataPoint.Description;
Cached = sparseDataPoint.Cached;
_featureIds = new int[sparseDataPoint._featureIds.Length];
FeatureValues = new float[sparseDataPoint.FeatureValues.Length];
Array.Copy(sparseDataPoint._featureIds, 0, _featureIds, 0, sparseDataPoint._featureIds.Length);
Array.Copy(sparseDataPoint.FeatureValues, 0, FeatureValues, 0, sparseDataPoint.FeatureValues.Length);
Label = dataPoint.Label;
Id = dataPoint.Id;
Description = dataPoint.Description;
Cached = dataPoint.Cached;
_featureIds = new int[dataPoint._featureIds.Length];
FeatureValues = new float[dataPoint.FeatureValues.Length];
FeatureCount = dataPoint.FeatureCount;
Array.Copy(dataPoint._featureIds, 0, _featureIds, 0, dataPoint._featureIds.Length);
Array.Copy(dataPoint.FeatureValues, 0, FeatureValues, 0, dataPoint.FeatureValues.Length);
}

private int Locate(int fid)
22 changes: 16 additions & 6 deletions src/RankLib/Learning/Tree/RandomForests.cs
@@ -34,6 +34,16 @@ public class RandomForestsParameters : IRankerParameters
/// </summary>
public const float DefaultSubSamplingRate = 1.0f;

/// <summary>
/// Default number of trees in each bag
/// </summary>
public const int DefaultTreeCount = 1;

/// <summary>
/// Default number of tree leaves
/// </summary>
public const int DefaultTreeLeavesCount = 100;

private RankerType _rankerType = DefaultRankerType;

// Parameters
@@ -75,27 +85,27 @@ public RankerType RankerType
/// <summary>
/// Number of trees in each bag
/// </summary>
public int TreeCount { get; set; } = 1;
public int TreeCount { get; set; } = DefaultTreeCount;

/// <summary>
/// Number of leaves in each tree
/// </summary>
public int TreeLeavesCount { get; set; } = 100;
public int TreeLeavesCount { get; set; } = DefaultTreeLeavesCount;

/// <summary>
/// The learning rate, or shrinkage, only matters if <see cref="TreeCount"/> > 1
/// The learning rate, or shrinkage, only matters if <see cref="TreeCount"/> is greater than 1
/// </summary>
public float LearningRate { get; set; } = 0.1F;
public float LearningRate { get; set; } = LambdaMARTParameters.DefaultLearningRate;

/// <summary>
/// The number of threshold candidates.
/// </summary>
public int Threshold { get; set; } = 256;
public int Threshold { get; set; } = LambdaMARTParameters.DefaultThreshold;

/// <summary>
/// Minimum leaf support
/// </summary>
public int MinimumLeafSupport { get; set; } = 1;
public int MinimumLeafSupport { get; set; } = LambdaMARTParameters.DefaultMinimumLeafSupport;

/// <summary>
/// Gets or sets the maximum number of concurrent tasks allowed when splitting up workloads
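The new constants above give the RandomForestsParameters defaults names and align the shared ones with LambdaMARTParameters. A small configuration sketch; it assumes the LambdaMARTParameters defaults are accessible from calling code and that the types live in the namespaces shown, neither of which is visible in this diff.

using RankLib.Learning;
using RankLib.Learning.Tree;

var parameters = new RandomForestsParameters
{
    // Setting each property to the default it already has, purely for illustration.
    TreeCount = RandomForestsParameters.DefaultTreeCount,             // 1 tree per bag
    TreeLeavesCount = RandomForestsParameters.DefaultTreeLeavesCount, // 100 leaves per tree
    LearningRate = LambdaMARTParameters.DefaultLearningRate,          // shrinkage, matters when TreeCount > 1
    Threshold = LambdaMARTParameters.DefaultThreshold,                // number of threshold candidates
    MinimumLeafSupport = LambdaMARTParameters.DefaultMinimumLeafSupport,
};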
4 changes: 2 additions & 2 deletions src/RankLib/Learning/Tree/RegressionTree.cs
@@ -24,9 +24,9 @@ public RegressionTree(Split root)
_leaves = root.Leaves();
}

public RegressionTree(int nLeaves, DataPoint[] trainingSamples, double[] labels, FeatureHistogram hist, int minLeafSupport)
public RegressionTree(int treeLeavesCount, DataPoint[] trainingSamples, double[] labels, FeatureHistogram hist, int minLeafSupport)
{
_nodes = nLeaves;
_nodes = treeLeavesCount;
_trainingLabels = labels;
_hist = hist;
_minLeafSupport = minLeafSupport;
7 changes: 5 additions & 2 deletions src/RankLib/Learning/Tree/Split.cs
@@ -59,12 +59,15 @@ public void Set(int featureId, float threshold, double deviance)
_deviance = deviance;
}

/// <summary>
/// Gets the feature histogram
/// </summary>
public FeatureHistogram? Histogram { get; private set; }

/// <summary>
/// Whether this split is a root split.
/// </summary>
public bool IsRoot { get; set; }
public bool IsRoot { get; init; }

/// <summary>
/// Gets or sets the left split.
@@ -160,7 +163,7 @@ private string GetString(string indent)

// Internal functions (ONLY used during learning)
//*DO NOT* attempt to call them once the training is done
public async Task<bool> TrySplitAsync(double[] trainingLabels, int minLeafSupport)
internal async Task<bool> TrySplitAsync(double[] trainingLabels, int minLeafSupport)
{
if (Histogram is null)
throw new InvalidOperationException("Histogram is null");