Merge pull request #2368 from dathere/smarter-pivot
feat: an even smarter `pivotp`
jqnatividad authored Dec 22, 2024
2 parents 3f66c50 + 3bcdac0 commit 61af751
Showing 3 changed files with 176 additions and 27 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -63,7 +63,7 @@
| [lens](/src/cmd/lens.rs#L2)✨ | Interactively view, search & filter a CSV using the [csvlens](https://github.com/YS-L/csvlens#csvlens) engine.
| <a name="luau_deeplink"></a><br>[luau](/src/cmd/luau.rs#L2) 👑✨<br>📇🌐🔣📚 ![CKAN](docs/images/ckan.png) | Create multiple new computed columns, filter rows, compute aggregations and build complex data pipelines by executing a [Luau](https://luau-lang.org) [0.653](https://github.com/Roblox/luau/releases/tag/0.653) expression/script for every row of a CSV file ([sequential mode](https://github.com/dathere/qsv/blob/bb72c4ef369d192d85d8b7cc6e972c1b7df77635/tests/test_luau.rs#L254-L298)), or using [random access](https://www.webopedia.com/definitions/random-access/) with an index ([random access mode](https://github.com/dathere/qsv/blob/bb72c4ef369d192d85d8b7cc6e972c1b7df77635/tests/test_luau.rs#L367-L415)).<br>Can process a single Luau expression or [full-fledged data-wrangling scripts using lookup tables](https://github.com/dathere/qsv-lookup-tables#example) with discrete BEGIN, MAIN and END sections.<br> It is not just another qsv command, it is qsv's [Domain-specific Language](https://en.wikipedia.org/wiki/Domain-specific_language) (DSL) with [numerous qsv-specific helper functions](https://github.com/dathere/qsv/blob/113eee17b97882dc368b2e65fec52b86df09f78b/src/cmd/luau.rs#L1356-L2290) to build production data pipelines. |
| [partition](/src/cmd/partition.rs#L2)<br>👆 | Partition a CSV based on a column value. |
| [pivotp](/src/cmd/pivotp.rs#L2)✨<br>🚀🐻‍❄️🪄 | Pivot CSV data. |
| [pivotp](/src/cmd/pivotp.rs#L2)✨<br>🚀🐻‍❄️🪄 | Pivot CSV data. Features "smart" aggregation auto-selection based on data type & stats. |
| [pro](/src/cmd/pro.rs#L2) | Interact with the [qsv pro](https://qsvpro.dathere.com) API. |
| [prompt](/src/cmd/prompt.rs#L2)| Open a file dialog to either pick a file as input or save output to a file. |
| [pseudo](/src/cmd/pseudo.rs#L2)<br>🔣👆 | [Pseudonymise](https://en.wikipedia.org/wiki/Pseudonymization) the values of the given column by replacing them with an incremental identifier. |
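Below is a minimal sketch of exercising the feature the README row describes. It is illustrative only: `sales.csv` and its `quarter`/`amount` columns are hypothetical, while the `<on-cols>` positional and the `--values`/`--agg` flags come from the usage text further down.

```rust
use std::process::Command;

fn main() -> std::io::Result<()> {
    // Hypothetical invocation: pivot sales.csv on `quarter`, letting
    // `--agg smart` pick an aggregation from the value column's stats.
    let output = Command::new("qsv")
        .args(["pivotp", "quarter", "sales.csv",
               "--values", "amount", "--agg", "smart"])
        .output()?;
    print!("{}", String::from_utf8_lossy(&output.stdout));
    // Smart-selection info messages, if any, arrive on stderr.
    eprint!("{}", String::from_utf8_lossy(&output.stderr));
    Ok(())
}
```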
189 changes: 169 additions & 20 deletions src/cmd/pivotp.rs
@@ -5,7 +5,9 @@ The pivot operation consists of:
- One or more index columns (these will be the new rows)
- A column that will be pivoted (this will create the new columns)
- A values column that will be aggregated
- An aggregation function to apply
- An aggregation function to apply. Features "smart" aggregation auto-selection.
For examples, see https://github.com/dathere/qsv/blob/master/tests/test_pivotp.rs.
Usage:
qsv pivotp [options] <on-cols> <input>
@@ -36,15 +38,21 @@ pivotp options:
median - Median value
count - Count of values
last - Last value encountered
[default: count]
none - No aggregation is done. Raises an error if a group contains multiple values.
smart - Use value column statistics to pick an aggregation.
Only works when there is a single value column, otherwise
it falls back to `first`.
smartq - Same as smart, but suppresses the informational messages.
[default: smart]
--sort-columns Sort the transposed columns by name. Default is by order of discovery.
--col-separator <arg> The separator in generated column names in case of multiple --values columns.
[default: _]
--validate Validate a pivot by checking the pivot column(s)' cardinality.
--try-parsedates When set, will attempt to parse columns as dates.
--infer-len <arg> Number of rows to scan when inferring schema.
Set to 0 to scan entire file. [default: 10000]
--decimal-comma Use comma as decimal separator.
--decimal-comma Use comma as decimal separator when READING the input.
Note that you will need to specify an alternate --delimiter.
--ignore-errors Skip rows that can't be parsed.
Common options:
@@ -54,20 +62,24 @@ Common options:
Must be a single character. (default: ,)
"#;

use std::{fs::File, io, io::Write, path::Path};
use std::{fs::File, io, io::Write, path::Path, sync::OnceLock};

use csv::ByteRecord;
use indicatif::HumanCount;
use polars::prelude::*;
use polars_ops::pivot::{pivot_stable, PivotAgg};
use serde::Deserialize;

use crate::{
config::Delimiter,
cmd::stats::StatsData,
config::{Config, Delimiter},
util,
util::{get_stats_records, StatsMode},
CliResult,
};

static STATS_RECORDS: OnceLock<(ByteRecord, Vec<StatsData>)> = OnceLock::new();
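The new `OnceLock` static lets `calculate_pivot_metadata` and `suggest_agg_function` share a single stats computation per process instead of running `get_stats_records` twice. A minimal, self-contained sketch of the `get_or_init` caching pattern (the cached `Vec<u32>` is an illustrative stand-in for the real `(ByteRecord, Vec<StatsData>)` tuple):

```rust
use std::sync::OnceLock;

static CACHE: OnceLock<Vec<u32>> = OnceLock::new();

fn expensive_stats() -> Vec<u32> {
    println!("computed exactly once");
    (0..5).map(|n| n * n).collect()
}

fn main() {
    // The closure runs on the first call only; later calls return the cached value.
    let first = CACHE.get_or_init(expensive_stats);
    let second = CACHE.get_or_init(expensive_stats);
    assert!(std::ptr::eq(first, second)); // same allocation, no recomputation
}
```

One trade-off worth noting: because a failed stats run is converted to an empty `ByteRecord`/`Vec` before caching, the failure is memoized too, and later callers see empty stats rather than a retry.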

#[derive(Deserialize)]
struct Args {
arg_on_cols: String,
Expand Down Expand Up @@ -115,11 +127,10 @@ fn calculate_pivot_metadata(
flag_memcheck: false,
};

let Ok((csv_fields, csv_stats)) =
let (csv_fields, csv_stats) = STATS_RECORDS.get_or_init(|| {
get_stats_records(&schema_args, StatsMode::FrequencyForceStats)
else {
return Ok(None);
};
.unwrap_or_else(|_| (ByteRecord::new(), Vec::new()))
});

if csv_stats.is_empty() {
return Ok(None);
@@ -183,6 +194,113 @@ fn validate_pivot_operation(metadata: &PivotMetadata) -> CliResult<()> {
Ok(())
}

/// Suggest an appropriate aggregation function based on column statistics
#[allow(clippy::cast_precision_loss)]
fn suggest_agg_function(
args: &Args,
value_cols: &[String],
quiet: bool,
) -> CliResult<Option<PivotAgg>> {
let schema_args = util::SchemaArgs {
flag_enum_threshold: 0,
flag_ignore_case: false,
flag_strict_dates: false,
flag_pattern_columns: crate::select::SelectColumns::parse("").unwrap(),
flag_dates_whitelist: String::new(),
flag_prefer_dmy: false,
flag_force: false,
flag_stdout: false,
flag_jobs: None,
flag_no_headers: false,
flag_delimiter: args.flag_delimiter,
arg_input: Some(args.arg_input.clone()),
flag_memcheck: false,
};

let (csv_fields, csv_stats) = STATS_RECORDS.get_or_init(|| {
get_stats_records(&schema_args, StatsMode::FrequencyForceStats)
.unwrap_or_else(|_| (ByteRecord::new(), Vec::new()))
});

// If multiple value columns, default to First
if value_cols.len() > 1 {
return Ok(Some(PivotAgg::First));
}

// Get stats for the value column
let value_col = &value_cols[0];
let field_pos = csv_fields
.iter()
.position(|f| std::str::from_utf8(f).unwrap_or("") == value_col);

if let Some(pos) = field_pos {
let stats = &csv_stats[pos];
let rconfig = Config::new(Some(&args.arg_input));
let row_count = util::count_rows(&rconfig)? as u64;

// Suggest aggregation based on field type and statistics
let suggested_agg = match stats.r#type.as_str() {
"NULL" => {
if !quiet {
eprintln!("Info: \"{value_col}\" contains only NULL values");
}
PivotAgg::Count
},
"Integer" | "Float" => {
if stats.nullcount as f64 / row_count as f64 > 0.5 {
if !quiet {
eprintln!("Info: \"{value_col}\" contains >50% NULL values, using Count");
}
PivotAgg::Count
} else {
PivotAgg::Sum
}
},
"Date" | "DateTime" => {
if stats.cardinality as f64 / row_count as f64 > 0.9 {
if !quiet {
eprintln!(
"Info: {} column \"{value_col}\" has high cardinality, using First",
stats.r#type
);
}
PivotAgg::First
} else {
if !quiet {
eprintln!(
"Info: \"{value_col}\" is a {} column, using Count",
stats.r#type
);
}
PivotAgg::Count
}
},
_ => {
if stats.cardinality == row_count {
if !quiet {
eprintln!("Info: \"{value_col}\" contains all unique values, using First");
}
PivotAgg::First
} else if stats.cardinality as f64 / row_count as f64 > 0.5 {
if !quiet {
eprintln!("Info: \"{value_col}\" has high cardinality, using Count");
}
PivotAgg::Count
} else {
if !quiet {
eprintln!("Info: \"{value_col}\" is a String column, using Count");
}
PivotAgg::Count
}
},
};

Ok(Some(suggested_agg))
} else {
Ok(None)
}
}
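Distilled, `suggest_agg_function` reduces to a small decision table over the column's type, NULL share, and cardinality. Here is a standalone restatement, offered as a sketch: `null_ratio` and `cardinality_ratio` stand for the `nullcount / row_count` and `cardinality / row_count` quantities computed above.

```rust
/// Sketch of the smart-aggregation decision table implemented above.
fn pick_agg(dtype: &str, null_ratio: f64, cardinality_ratio: f64) -> &'static str {
    match dtype {
        // Only NULLs: nothing meaningful to sum, so count occurrences.
        "NULL" => "count",
        // Numeric: sum, unless the column is mostly NULL.
        "Integer" | "Float" => if null_ratio > 0.5 { "count" } else { "sum" },
        // Temporal: near-unique values pivot best as first-encountered.
        "Date" | "DateTime" => if cardinality_ratio > 0.9 { "first" } else { "count" },
        // Strings and everything else: all-unique means first, otherwise count.
        _ => if cardinality_ratio >= 1.0 { "first" } else { "count" },
    }
}

fn main() {
    assert_eq!(pick_agg("Float", 0.1, 0.3), "sum");
    assert_eq!(pick_agg("Date", 0.0, 0.95), "first");
    assert_eq!(pick_agg("String", 0.0, 0.4), "count");
}
```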

pub fn run(argv: &[&str]) -> CliResult<()> {
let args: Args = util::get_args(USAGE, argv)?;

@@ -226,17 +344,42 @@ pub fn run(argv: &[&str]) -> CliResult<()> {

// Get aggregation function
let agg_fn = if let Some(ref agg) = args.flag_agg {
Some(match agg.to_lowercase().as_str() {
"first" => PivotAgg::First,
"sum" => PivotAgg::Sum,
"min" => PivotAgg::Min,
"max" => PivotAgg::Max,
"mean" => PivotAgg::Mean,
"median" => PivotAgg::Median,
"count" => PivotAgg::Count,
"last" => PivotAgg::Last,
_ => return fail_clierror!("Invalid pivot aggregation function: {agg}"),
})
let lower_agg = agg.to_lowercase();
if lower_agg == "none" {
None
} else {
Some(match lower_agg.as_str() {
"first" => PivotAgg::First,
"sum" => PivotAgg::Sum,
"min" => PivotAgg::Min,
"max" => PivotAgg::Max,
"mean" => PivotAgg::Mean,
"median" => PivotAgg::Median,
"count" => PivotAgg::Count,
"last" => PivotAgg::Last,
"smart" | "smartq" => {
if let Some(value_cols) = &value_cols {
// Try to suggest an appropriate aggregation function
if let Some(suggested_agg) =
suggest_agg_function(&args, value_cols, lower_agg == "smartq")?
{
suggested_agg
} else {
// fallback to first, which always works
PivotAgg::First
}
} else {
// Default to Count if no value columns specified
PivotAgg::Count
}
},
_ => {
return fail_incorrectusage_clierror!(
"Invalid pivot aggregation function: {agg}"
)
},
})
}
} else {
None
};
@@ -248,6 +391,12 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
b','
};

if args.flag_decimal_comma && delim == b',' {
return fail_incorrectusage_clierror!(
"You need to specify an alternate --delimiter when using --decimal-comma."
);
}
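The guard is needed because a comma cannot serve as both the decimal mark and the field delimiter. A small hypothetical illustration of why an alternate delimiter disambiguates:

```rust
fn main() {
    // Semicolon-delimited record using comma as the decimal mark.
    let record = "widget;19,99";
    let fields: Vec<&str> = record.split(';').collect();
    let price: f64 = fields[1].replace(',', ".").parse().unwrap();
    assert_eq!(price, 19.99);
    // With ',' as the field delimiter too, "widget,19,99" would split
    // into ["widget", "19", "99"]; hence the required alternate --delimiter.
}
```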

// Create CSV reader config
let csv_reader = LazyCsvReader::new(&args.arg_input)
.with_has_header(true)
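Two behavioral notes on the aggregation dispatch above: with `--agg none`, `agg_fn` stays `None`, so per the usage text the pivot raises an error whenever a group holds more than one value, which effectively asserts that the on/index columns uniquely identify rows. And `smart` degrades gracefully: if stats cannot be obtained for the single value column it falls back to `First` (which always works), and if no value columns are specified it defaults to `Count`.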
12 changes: 6 additions & 6 deletions tests/test_pivotp.rs
@@ -352,7 +352,7 @@ pivotp_test!(
wrk.assert_success(&mut cmd);

let got: Vec<Vec<String>> = wrk.read_stdout(&mut cmd);
let expected = vec![svec!["date;A;B"], svec!["2023-01-01;1;1"]];
let expected = vec![svec!["date;A;B"], svec!["2023-01-01;100;150"]];
assert_eq!(got, expected);
}
);
@@ -549,7 +549,7 @@ pivotp_test!(
wrk.assert_success(&mut cmd);

let got: Vec<Vec<String>> = wrk.read_stdout(&mut cmd);
let expected = vec![svec!["date;A;B"], svec!["2023-01-01;1;1"]];
let expected = vec![svec!["date;A;B"], svec!["2023-01-01;100.5;150.75"]];
assert_eq!(got, expected);
}
);
@@ -577,8 +577,8 @@ pivotp_test!(
let got: Vec<Vec<String>> = wrk.read_stdout(&mut cmd);
let expected = vec![
svec!["date", "A", "B"],
svec!["2023-01-01", "2", "1"],
svec!["2023-01-02", "1", "2"],
svec!["2023-01-01", "300", "150"],
svec!["2023-01-02", "300", "600"],
];
assert_eq!(got, expected);
}
@@ -604,8 +604,8 @@ pivotp_test!(
let got: Vec<Vec<String>> = wrk.read_stdout(&mut cmd);
let expected = vec![
svec!["date", "A", "B"],
svec!["2023-01-01", "2", "1"],
svec!["2023-01-02", "1", "2"],
svec!["2023-01-01", "300", "150"],
svec!["2023-01-02", "300", "600"],
];
assert_eq!(got, expected);
}
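The updated expectations are consistent with the new `smart` default: the numeric value columns in these fixtures are now summed instead of counted, so a group that previously reported a count of 2 now reports the sum of its two values (e.g., 300), and single-value groups report the value itself (100, 150.75, and so on).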
