Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow comma-separated filenames in multi-chain configurations #1312

Merged
merged 18 commits into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/cmdstan/arguments/arg_diagnostic_file.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@ class arg_diagnostic_file : public string_argument {
public:
arg_diagnostic_file() : string_argument() {
_name = "diagnostic_file";
_description = "Auxiliary output file for diagnostic information";
_validity = "Path to existing file";
_description
= "Auxiliary output file for diagnostic information. If multiple "
"chains are run, this can either be a single path, in which case its "
"name will have _ID appended, or a comma-separated list of names.";
_validity = "File(s) should not already exist";
_default = "\"\"";
_default_value = "";
_value = _default_value;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ class arg_generate_quantities_fitted_params : public string_argument {
_name = "fitted_params";
_description
= "Input file of sample of fitted parameter values for model "
"conditioned on data";
"conditioned on data. If multiple chains are run, this can either "
"be a single path, in which case its name will have _ID appended, or "
"a comma-separated list of names.";
_validity = "Path to existing file";
_default = "\"\"";
_default_value = "";
Expand Down
6 changes: 5 additions & 1 deletion src/cmdstan/arguments/arg_init.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ class arg_init : public string_argument {
_description = std::string("Initialization method: ")
+ std::string("\"x\" initializes randomly between [-x, x], ")
+ std::string("\"0\" initializes to 0, ")
+ std::string("anything else identifies a file of values");
+ std::string(
"anything else identifies a file of values. If "
"multiple chains are run, this can either be a single "
"path, in which case its name will have _ID appended, "
"or a comma-separated list of names.");
_default = "\"2\"";
_default_value = "2";
_value = _default_value;
Expand Down
7 changes: 5 additions & 2 deletions src/cmdstan/arguments/arg_output_file.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@ class arg_output_file : public string_argument {
public:
arg_output_file() : string_argument() {
_name = "file";
_description = "Output file";
_validity = "Path to existing file";
_description
= "Output file. If multiple chains are run, this can either be a "
"single path, in which case its name will have _ID appended, or a "
"comma-separated list of names.";
_validity = "File(s) should not already exist";
_default = "output.csv";
_default_value = "output.csv";
_value = _default_value;
Expand Down
15 changes: 10 additions & 5 deletions src/cmdstan/command.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,12 +364,17 @@ int command(int argc, const char *argv[]) {
"Missing fitted_params argument, cannot run generate_quantities "
"without fitted sample.");
}
auto file_info = file::get_basename_suffix(fname);
if (file_info.second != ".csv") {
throw std::invalid_argument("Fitted params file must be a CSV file.");
}

std::vector<std::string> fname_vec
= file::make_filenames(file_info.first, "", ".csv", num_chains, id);
= file::make_filenames(fname, "", ".csv", num_chains, id);

for (auto &f : fname_vec) {
auto file_info = file::get_basename_suffix(f);
if (file_info.second != ".csv") {
throw std::invalid_argument("Fitted params file must be a CSV file.");
}
}

std::vector<std::string> param_names = get_constrained_param_names(model);
std::vector<Eigen::MatrixXd> fitted_params_vec;
fitted_params_vec.reserve(num_chains);
Expand Down
99 changes: 54 additions & 45 deletions src/cmdstan/command_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,57 +147,74 @@ context_vector get_vec_var_context(const std::string &file, size_t num_chains,
if (num_chains == 1) {
return context_vector(1, get_var_context(file));
}
auto make_context = [](auto &&file, auto &&stream,
auto &&file_ending) -> shared_context_ptr {
auto make_context = [](auto &&file, auto &&stream) -> shared_context_ptr {
auto [file_name, file_ending] = file::get_basename_suffix(file);
if (file_ending == ".json") {
using stan::json::json_data;
return std::make_shared<json_data>(json_data(stream));
} else if (file_ending == ".R") {
using stan::io::dump;
return std::make_shared<stan::io::dump>(dump(stream));
} else {
// should never happen, caught by check_valid_file below
std::stringstream msg;
msg << "file ending of " << file_ending << " is not supported by cmdstan";
throw std::invalid_argument(msg.str());
using stan::io::dump;
return std::make_shared<dump>(dump(stream));
}
};
// use default for all chain inits
if (file.empty()) {
return context_vector(num_chains,
std::make_shared<stan::io::empty_var_context>());
} else {
size_t file_marker_pos = file.find_last_of(".");
if (file_marker_pos > file.size()) {
std::stringstream msg;
msg << "Found: \"" << file
<< "\" but user specified files must end in .json or .R";
throw std::invalid_argument(msg.str());
}
std::string file_name = file.substr(0, file_marker_pos);
std::string file_ending = file.substr(file_marker_pos, file.size());
if (file_ending != ".json" && file_ending != ".R") {
std::stringstream msg;
msg << "file ending of " << file_ending << " is not supported by cmdstan";
throw std::invalid_argument(msg.str());
}
if (file_ending != ".json") {
std::cerr
<< "Warning: file '" << file
<< "' is being read as an 'RDump' file.\n"
"\tThis format is deprecated and will not receive new features.\n"
"\tConsider saving your data in JSON format instead."
<< std::endl;
auto check_valid_file = [](const std::string &file) {
auto [file_name, file_ending] = file::get_basename_suffix(file);

if (file_ending.empty()) {
std::stringstream msg;
msg << "Found: \"" << file
<< "\" but user specified files must end in .json or .R";
throw std::invalid_argument(msg.str());
}
if (file_ending != ".json" && file_ending != ".R") {
std::stringstream msg;
msg << "file ending of " << file_ending
<< " is not supported by cmdstan";
throw std::invalid_argument(msg.str());
}
if (file_ending != ".json") {
std::cerr << "Warning: file '" << file
<< "' is being read as an 'RDump' file.\n"
"\tThis format is deprecated and will not receive new "
"features.\n"
"\tConsider saving your data in JSON format instead."
<< std::endl;
}
};

bool has_commas = file.find(',') != std::string::npos;
auto filenames = file::make_filenames(file, "", "", num_chains, id);

if (has_commas) {
for (auto &&file_name : filenames) {
check_valid_file(file_name);
}
} else {
check_valid_file(file);
}

auto filenames
= file::make_filenames(file_name, "", file_ending, num_chains, id);
auto &file_1 = filenames[0];
std::fstream stream_1(file_1.c_str(), std::fstream::in);
// if file_1 exists we'll assume num_chains of these files exist
if (stream_1.rdstate() & std::ifstream::failbit) {
// if that fails we will try to find a base file
// if we were given a user-specified list, this is an error
if (has_commas) {
std::stringstream msg;
msg << "cannot open \"" << file_1 << "\"" << std::endl;
throw std::invalid_argument(msg.str());
}

// otherwise, we will try to find a base file and use n copies of it
std::fstream stream(file.c_str(), std::fstream::in);
if (stream.rdstate() & std::ifstream::failbit) {
std::string file_name_err
Expand All @@ -208,40 +225,32 @@ context_vector get_vec_var_context(const std::string &file, size_t num_chains,
<< std::endl;
throw std::invalid_argument(msg.str());
} else {
return context_vector(num_chains,
make_context(file, stream, file_ending));
return context_vector(num_chains, make_context(file, stream));
}
} else {
// If we found file_1 then we'll assume file_{1...N} exists
context_vector ret;
ret.reserve(num_chains);
ret.push_back(make_context(file_1, stream_1, file_ending));
ret.push_back(make_context(file_1, stream_1));
for (size_t i = 1; i < num_chains; ++i) {
auto &file_i = filenames[i];
std::fstream stream_i(file_i.c_str(), std::fstream::in);
// If any stream fails here something went wrong with file names
if (stream_i.rdstate() & std::ifstream::failbit) {
std::string file_name_err = std::string(
"\"" + file_1 + "\" but cannot open \"" + file_i + "\"");
std::stringstream msg;
msg << "Found " << file_name_err << std::endl;
if (!has_commas) {
// in this case, we generated a template from the given name
msg << "Given the template \"" << file << "\", found \"" << file_1
<< "\" but ";
}
msg << "cannot open \"" << file_i << "\"" << std::endl;
throw std::invalid_argument(msg.str());
}
ret.push_back(make_context(file_i, stream_i, file_ending));
ret.push_back(make_context(file_i, stream_i));
}
return ret;
}
}
// This should not happen
std::cerr
<< "Warning: file '" << file
<< "' is being read as an 'RDump' file.\n"
"\tThis format is deprecated and will not receive new features.\n"
"\tConsider saving your data in JSON format instead."
<< std::endl;
using stan::io::dump;
std::fstream stream(file.c_str(), std::fstream::in);
return context_vector(num_chains, std::make_shared<dump>(dump(stream)));
}

/**
Expand Down
62 changes: 48 additions & 14 deletions src/cmdstan/file.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,23 +191,57 @@ std::vector<std::string> make_filenames(const std::string &filename,
const std::string &type,
unsigned int num_chains,
unsigned int id) {
std::pair<std::string, std::string> base_sfx;
base_sfx = get_basename_suffix(filename);
if (type != ".csv" || base_sfx.second.empty()) {
base_sfx.second = type;
}

std::vector<std::string> names(num_chains);
auto name_iterator = [num_chains, id](auto i) {
if (num_chains == 1) {
return std::string("");
} else {
return std::string("_" + std::to_string(i + id));

// if a ',' is present, we assume the user fully specified the names
if (filename.find(',') != std::string::npos) {
std::vector<std::string> filenames;
boost::algorithm::split(filenames, filename, boost::is_any_of(","),
boost::token_compress_on);
if (filenames.size() != num_chains) {
std::stringstream msg;
msg << "Number of filenames does not match number of chains: got "
"comma-separated list '"
<< filename << "' of length " << filenames.size() << " but expected "
<< num_chains << " names" << std::endl;
throw std::invalid_argument(msg.str());
}

std::transform(filenames.cbegin(), filenames.cend(), names.begin(),
[&tag, &type](const std::string &name) {
auto [base_name, sfx] = get_basename_suffix(name);
if ((!type.empty() && type != ".csv") || sfx.empty()) {
sfx = type;
}
// TODO: in most cases tag is empty, it would be nice if it
// was never used for maximum user control
return base_name + tag + sfx;
});
} else {
// otherwise, this is a template which gets edited like output.csv ->
// output_1.csv
auto [base_name, sfx] = get_basename_suffix(filename);

// first condition here is legacy -- we used to be very lax
// about file names, but with things like json outputs
// we need to be stricter to avoid collisions, so we only
// allow laxity on the suffix for intended-to-be csv files
if ((!type.empty() && type != ".csv") || sfx.empty()) {
sfx = type;
}

auto name_iterator = [num_chains, id](auto i) {
if (num_chains == 1) {
return std::string("");
} else {
return std::string("_" + std::to_string(i + id));
}
};
for (int i = 0; i < num_chains; ++i) {
names[i] = base_name + tag + name_iterator(i) + sfx;
}
};
for (int i = 0; i < num_chains; ++i) {
names[i] = base_sfx.first + tag + name_iterator(i) + base_sfx.second;
}

return names;
}

Expand Down
25 changes: 25 additions & 0 deletions src/test/interface/file_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,31 @@ TEST(CommandHelper, make_filenames) {
EXPECT_EQ(names.size(), num_chains);
EXPECT_EQ(names[0], expect_n4_0);
EXPECT_EQ(names[1], expect_n4_1);

// comma-separated list of names
std::string fp3 = "foo" + sep + "bar" + sep + "baz.boz,foo" + sep + "bar"
+ sep + "gak.boz";
names.clear();
names.resize(0);
names = make_filenames(fp3, "_mu", ".csv", num_chains, id);
EXPECT_EQ(names.size(), num_chains);
EXPECT_EQ(names[0], "foo" + sep + "bar" + sep + "baz_mu.boz");
EXPECT_EQ(names[1], "foo" + sep + "bar" + sep + "gak_mu.boz");

// comma-separated list of names, non-csv
names.clear();
names.resize(0);
names = make_filenames(fp3, "_mu", ".json", num_chains, id);
EXPECT_EQ(names.size(), num_chains);
EXPECT_EQ(names[0], "foo" + sep + "bar" + sep + "baz_mu.json");
EXPECT_EQ(names[1], "foo" + sep + "bar" + sep + "gak_mu.json");

// comma-separated list (incorrect length)
std::string fp4 = fp3 + ",foo" + sep + "bar" + sep + "flux.boz";
names.clear();
names.resize(0);
EXPECT_THROW(names = make_filenames(fp4, "_mu", ".csv", num_chains, id),
std::invalid_argument);
}

TEST(CommandHelper, check_filename_config_good) {
Expand Down
13 changes: 13 additions & 0 deletions src/test/interface/generated_quantities_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,19 @@ TEST_F(CmdStan, generate_quantities_good_multi) {
ASSERT_FALSE(out.hasError);
}

// Runs generate_quantities with num_chains=2 where fitted_params is given as
// an explicit comma-separated pair of paths, and requires that the command
// reports no error.
TEST_F(CmdStan, generate_quantities_good_multi_comma) {
  std::stringstream cmd_builder;
  cmd_builder << convert_model_path(bern_gq_model);
  cmd_builder << " data file=" << convert_model_path(bern_data);
  cmd_builder << " output file=" << convert_model_path(dev_null_path);
  // the same fitted-params path is listed twice, once per chain
  cmd_builder << " method=generate_quantities fitted_params="
              << convert_model_path(bern_fitted_params) << ","
              << convert_model_path(bern_fitted_params);
  cmd_builder << " num_chains=2";
  std::string command = cmd_builder.str();
  run_command_output result = run_command(command);
  // on failure, surface the captured command output in the test log
  ASSERT_FALSE(result.hasError) << result.output;
}

TEST_F(CmdStan, generate_quantities_same_in_out_multi) {
std::stringstream ss;
ss << convert_model_path(bern_gq_model)
Expand Down
Loading