Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow comma-separated filenames in multi-chain configurations #1312

Merged
merged 18 commits into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions make/command
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
src/cmdstan/stansummary.d: DEPTARGETS = -MT bin/cmdstan/stansummary.o
# don't build anything during a `make clean`
ifneq ($(MAKECMDGOALS),)
ifeq ($(filter clean%,$(MAKECMDGOALS)),)
-include src/cmdstan/stansummary.d
endif
endif

bin/cmdstan/%.o : src/cmdstan/%.cpp
@mkdir -p $(dir $@)
Expand Down
14 changes: 9 additions & 5 deletions make/tests
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ test/%$(EXE) : INC += $(INC_GTEST) -I $(RAPIDJSON)
test/%$(EXE) : test/%.o $(GTEST)/src/gtest_main.cc $(GTEST)/src/gtest-all.o $(SUNDIALS_TARGETS) $(MPI_TARGETS) $(TBB_TARGETS)
$(LINK.cpp) $(filter-out src/test/test-models/% src/%.csv bin/% test/%.hpp %.hpp-test,$^) $(LDLIBS) $(OUTPUT_OPTION)

test/%.o : src/test/%.cpp
.PRECIOUS: test/%.o
test/%.o : src/test/%.cpp src/test/utility.hpp $(wildcard src/cmdstan/*.hpp)
@mkdir -p $(dir $@)
$(COMPILE.cpp) $< $(OUTPUT_OPTION)

Expand All @@ -23,7 +24,6 @@ src/test/%.d : test/%.o

ifneq ($(filter test/%,$(MAKECMDGOALS)),)
-include $(patsubst test/%$(EXE),src/test/%.d,$(filter test/%,$(MAKECMDGOALS)))
-include $(patsubst %.cpp,%.d,$(STANC_TEMPLATE_INSTANTIATION_CPP))
endif

############################################################
Expand Down Expand Up @@ -56,10 +56,14 @@ test-headers: $(HEADER_TESTS)
##
TEST_MODELS := $(wildcard src/test/test-models/*.stan)

ifneq ($(filter test-models-hpp,$(MAKECMDGOALS)),)
-include $(patsubst %.stan,%.d,$(TEST_MODELS))
include src/cmdstan/main.d
endif

.PHONY: test-models-hpp
test-models-hpp:
$(MAKE) $(patsubst %.stan,%.hpp,$(TEST_MODELS))
$(MAKE) $(patsubst %.stan,%$(EXE),$(TEST_MODELS))
test-models-hpp: $(patsubst %.stan,%.hpp,$(TEST_MODELS)) $(patsubst %.stan,%$(EXE),$(TEST_MODELS))

##
# Tests that depend on compiled models
##
Expand Down
8 changes: 8 additions & 0 deletions makefile
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,14 @@ build-mpi: $(MPI_TARGETS)
@echo ''
@echo '--- boost mpi bindings built ---'

# don't build anything during a `make clean`
# but otherwise, we always want to check main.d
ifneq ($(MAKECMDGOALS),)
ifeq ($(filter clean%,$(MAKECMDGOALS)),)
include src/cmdstan/main.d
endif
endif

.PHONY: build
build: bin/stanc$(EXE) $(SUNDIALS_TARGETS) $(MPI_TARGETS) $(TBB_TARGETS) $(CMDSTAN_MAIN_O) $(PRECOMPILED_MODEL_HEADER) bin/stansummary$(EXE) bin/print$(EXE) bin/diagnose$(EXE)
@echo ''
Expand Down
7 changes: 5 additions & 2 deletions src/cmdstan/arguments/arg_diagnostic_file.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@ class arg_diagnostic_file : public string_argument {
public:
arg_diagnostic_file() : string_argument() {
_name = "diagnostic_file";
_description = "Auxiliary output file for diagnostic information";
_validity = "Path to existing file";
_description
= "Auxiliary output file for diagnostic information. If multiple "
"chains are run, this can either be a single path, in which case its "
"name will have _ID appended, or a comma-separated list of names.";
_validity = "File(s) should not already exist";
_default = "\"\"";
_default_value = "";
_value = _default_value;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ class arg_generate_quantities_fitted_params : public string_argument {
_name = "fitted_params";
_description
= "Input file of sample of fitted parameter values for model "
"conditioned on data";
"conditioned on data. If multiple chains are run, this can either "
"be a single path, in which case its name will have _ID appended, or "
"a comma-separated list of names.";
_validity = "Path to existing file";
_default = "\"\"";
_default_value = "";
Expand Down
6 changes: 5 additions & 1 deletion src/cmdstan/arguments/arg_init.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ class arg_init : public string_argument {
_description = std::string("Initialization method: ")
+ std::string("\"x\" initializes randomly between [-x, x], ")
+ std::string("\"0\" initializes to 0, ")
+ std::string("anything else identifies a file of values");
+ std::string(
"anything else identifies a file of values. If "
"multiple chains are run, this can either be a single "
"path, in which case its name will have _ID appended, "
"or a comma-separated list of names.");
_default = "\"2\"";
_default_value = "2";
_value = _default_value;
Expand Down
7 changes: 5 additions & 2 deletions src/cmdstan/arguments/arg_output_file.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@ class arg_output_file : public string_argument {
public:
arg_output_file() : string_argument() {
_name = "file";
_description = "Output file";
_validity = "Path to existing file";
_description
= "Output file. If multiple chains are run, this can either be a "
"single path, in which case its name will have _ID appended, or a "
"comma-separated list of names.";
_validity = "File(s) should not already exist";
_default = "output.csv";
_default_value = "output.csv";
_value = _default_value;
Expand Down
15 changes: 10 additions & 5 deletions src/cmdstan/command.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,12 +364,17 @@ int command(int argc, const char *argv[]) {
"Missing fitted_params argument, cannot run generate_quantities "
"without fitted sample.");
}
auto file_info = file::get_basename_suffix(fname);
if (file_info.second != ".csv") {
throw std::invalid_argument("Fitted params file must be a CSV file.");
}

std::vector<std::string> fname_vec
= file::make_filenames(file_info.first, "", ".csv", num_chains, id);
= file::make_filenames(fname, "", ".csv", num_chains, id);

for (auto &f : fname_vec) {
auto file_info = file::get_basename_suffix(f);
if (file_info.second != ".csv") {
throw std::invalid_argument("Fitted params file must be a CSV file.");
}
}

std::vector<std::string> param_names = get_constrained_param_names(model);
std::vector<Eigen::MatrixXd> fitted_params_vec;
fitted_params_vec.reserve(num_chains);
Expand Down
190 changes: 101 additions & 89 deletions src/cmdstan/command_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,35 @@ inline constexpr auto get_arg_val(List &&arg_list, Args &&... args) {
}
}

/**
* Check that a file is either a .json or .R (dump) extension
*/
inline void check_valid_context_file_name(const std::string &file) {
auto [file_name, file_ending] = file::get_basename_suffix(file);
if (file_ending != ".json") {
if (file_ending != ".R") {
std::stringstream msg;
msg << "User specified files must end in .json or .R. Found: ";
if (file_ending.empty()) {
msg << file;
} else {
msg << file_ending;
}
msg << std::endl;
throw std::invalid_argument(msg.str());
}

std::cerr << "Warning: file '" << file
<< "' is being read as an 'RDump' file.\n"
"\tThis format is deprecated and will not receive new "
"features.\n"
"\tConsider saving your data in JSON format instead."
<< std::endl;
}
}

using shared_context_ptr = std::shared_ptr<stan::io::var_context>;

/**
* Given the name of a file, return a shared pointer holding the data contents.
* @param file A system file to read from
Expand All @@ -117,131 +145,115 @@ inline shared_context_ptr get_var_context(const std::string &file) {
if (file.empty()) {
return std::make_shared<stan::io::empty_var_context>();
}
check_valid_context_file_name(file);
std::ifstream stream = file::safe_open(file);
if (file::get_suffix(file) == ".json") {
stan::json::json_data var_context(stream);
return std::make_shared<stan::json::json_data>(var_context);
}
std::cerr
<< "Warning: file '" << file
<< "' is being read as an 'RDump' file.\n"
"\tThis format is deprecated and will not receive new features.\n"
"\tConsider saving your data in JSON format instead."
<< std::endl;
stan::io::dump var_context(stream);
return std::make_shared<stan::io::dump>(var_context);
}

using context_vector = std::vector<shared_context_ptr>;
/**
* Make a vector of shared pointers to contexts.
* @param file The name of the file. For multi-chain we will attempt to find
* {file_name}_1{file_ending} and if that fails try to use the named file as
* @param file The name of the file. For multi-chain, this can either be a
* comma-separated list, or else we will attempt to find
* {file_name}_{id}{file_ending} and if that fails try to use the named file as
* the data for each chain.
* @param num_chains The number of chains to run
* @param id The id of the first chain
* @return a std vector of shared pointers to var contexts
*/
context_vector get_vec_var_context(const std::string &file, size_t num_chains,
unsigned int id) {
using stan::io::var_context;
// simple handling for 1 chain
if (num_chains == 1) {
return context_vector(1, get_var_context(file));
}
auto make_context = [](auto &&file, auto &&stream,
auto &&file_ending) -> shared_context_ptr {
// use default for all chain inits
if (file.empty()) {
return context_vector(num_chains,
std::make_shared<stan::io::empty_var_context>());
}

const bool has_commas = file.find(',') != std::string::npos;
auto filenames = file::make_filenames(file, "", "", num_chains, id);

std::vector<std::string> missing_files;
std::vector<std::fstream> streams;
streams.reserve(num_chains);

// check files are valid and exist, or build up a list of the missing ones
for (auto &&file_name : filenames) {
check_valid_context_file_name(file_name);
std::fstream stream(file_name.c_str(), std::fstream::in);
if (stream.rdstate() & std::ifstream::failbit) {
missing_files.push_back(file_name);
}
streams.push_back(std::move(stream));
}

auto make_context = [](auto &&file, auto &&stream) -> shared_context_ptr {
auto [file_name, file_ending] = file::get_basename_suffix(file);
if (file_ending == ".json") {
using stan::json::json_data;
return std::make_shared<json_data>(json_data(stream));
} else if (file_ending == ".R") {
using stan::io::dump;
return std::make_shared<stan::io::dump>(dump(stream));
return std::make_shared<dump>(dump(stream));

} else {
// should never happen, caught by check_valid_context_file_name above
std::stringstream msg;
msg << "file ending of " << file_ending << " is not supported by cmdstan";
throw std::invalid_argument(msg.str());
using stan::io::dump;
return std::make_shared<dump>(dump(stream));
}
};
// use default for all chain inits
if (file.empty()) {
return context_vector(num_chains,
std::make_shared<stan::io::empty_var_context>());
} else {
size_t file_marker_pos = file.find_last_of(".");
if (file_marker_pos > file.size()) {
std::stringstream msg;
msg << "Found: \"" << file
<< "\" but user specified files must end in .json or .R";
throw std::invalid_argument(msg.str());
}
std::string file_name = file.substr(0, file_marker_pos);
std::string file_ending = file.substr(file_marker_pos, file.size());
if (file_ending != ".json" && file_ending != ".R") {
std::stringstream msg;
msg << "file ending of " << file_ending << " is not supported by cmdstan";
throw std::invalid_argument(msg.str());
}
if (file_ending != ".json") {
std::cerr
<< "Warning: file '" << file
<< "' is being read as an 'RDump' file.\n"
"\tThis format is deprecated and will not receive new features.\n"
"\tConsider saving your data in JSON format instead."
<< std::endl;
}

auto filenames
= file::make_filenames(file_name, "", file_ending, num_chains, id);
auto &file_1 = filenames[0];
std::fstream stream_1(file_1.c_str(), std::fstream::in);
// if file_1 exists we'll assume num_chains of these files exist
if (stream_1.rdstate() & std::ifstream::failbit) {
// if that fails we will try to find a base file
std::fstream stream(file.c_str(), std::fstream::in);
if (stream.rdstate() & std::ifstream::failbit) {
std::string file_name_err
= std::string("\"" + file_1 + "\" and base file \"" + file + "\"");
std::stringstream msg;
msg << "Searching for \"" << file_name_err << std::endl;
msg << "Can't open either of specified files," << file_name_err
<< std::endl;
throw std::invalid_argument(msg.str());
} else {
return context_vector(num_chains,
make_context(file, stream, file_ending));
}
} else {
// If we found file_1 then we'll assume file_{1...N} exists
context_vector ret;
ret.reserve(num_chains);
ret.push_back(make_context(file_1, stream_1, file_ending));
for (size_t i = 1; i < num_chains; ++i) {
auto &file_i = filenames[i];
std::fstream stream_i(file_i.c_str(), std::fstream::in);
// If any stream fails here something went wrong with file names
if (stream_i.rdstate() & std::ifstream::failbit) {
std::string file_name_err = std::string(
"\"" + file_1 + "\" but cannot open \"" + file_i + "\"");
std::stringstream msg;
msg << "Found " << file_name_err << std::endl;
throw std::invalid_argument(msg.str());
}
ret.push_back(make_context(file_i, stream_i, file_ending));
}
return ret;
}
// happy path - all files exist and we can return the contexts
if (missing_files.empty()) {
context_vector ret(num_chains);
std::transform(filenames.cbegin(), filenames.cend(), streams.begin(),
ret.begin(), make_context);
return ret;
}

// user directly specified a list of files, some of which don't exist
if (has_commas && !missing_files.empty()) {
std::stringstream msg;
msg << "Cannot open some of the requested files: [";
msg << boost::algorithm::join(missing_files, ", ");
msg << "]" << std::endl;
throw std::invalid_argument(msg.str());
}
// This should not happen
std::cerr
<< "Warning: file '" << file
<< "' is being read as an 'RDump' file.\n"
"\tThis format is deprecated and will not receive new features.\n"
"\tConsider saving your data in JSON format instead."
<< std::endl;
using stan::io::dump;

// legacy -- if the user requested 'init.json', we looked for 'init_1.json'
// but if that fails, we try 'init.json' as well
std::fstream stream(file.c_str(), std::fstream::in);
return context_vector(num_chains, std::make_shared<dump>(dump(stream)));
if (stream.rdstate() & std::ifstream::failbit) {
std::stringstream msg;
msg << "Cannot open some of the requested files: [";
msg << boost::algorithm::join(missing_files, ", ");
msg << "]" << std::endl;
msg << "Also failed to find base file " << file << std::endl;
msg << "When cmdstan is given a file 'name' and there are "
"multiple chains or pathfinders,"
" cmdstan will look for files 'name_{N..(N + "
"num_processes)' where N is the id (typically, 1)."
" If these are not found, then it looks for the exact "
"file name as passed."
" In this case, neither option was found.";

throw std::invalid_argument(msg.str());
} else {
std::cerr << "Warning: file '" << file
<< "' is being used to initialize all " << num_chains
<< " chains!" << std::endl;
return context_vector(num_chains, make_context(file, std::move(stream)));
}
}

/**
Expand Down
Loading