Skip to content

Commit

Permalink
Some llama-run cleanups
Browse files Browse the repository at this point in the history
Use consolidated open function call from File class. Change
read_all to to_string(). Remove exclusive locking, the intent for
that lock is to avoid multiple processes writing to the same file,
it's not an issue for readers, although we may want to consider
adding a shared lock. Remove passing nullptr as reference,
references are never supposed to be null. clang-format the code
for consistent styling.

Signed-off-by: Eric Curtin <[email protected]>
  • Loading branch information
ericcurtin committed Feb 20, 2025
1 parent 4806498 commit 3caa933
Showing 1 changed file with 45 additions and 46 deletions.
91 changes: 45 additions & 46 deletions examples/run/run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,25 +323,17 @@ class File {
return 0;
}

std::string read_all(const std::string & filename){
open(filename, "r");
lock();
if (!file) {
printe("Error opening file '%s': %s", filename.c_str(), strerror(errno));
return "";
}

std::string to_string() {
fseek(file, 0, SEEK_END);
size_t size = ftell(file);
const size_t size = ftell(file);
fseek(file, 0, SEEK_SET);

std::string out;
out.resize(size);
size_t read_size = fread(&out[0], 1, size, file);
const size_t read_size = fread(&out[0], 1, size, file);
if (read_size != size) {
printe("Error reading file '%s': %s", filename.c_str(), strerror(errno));
return "";
printe("Error reading file: %s", strerror(errno));
}

return out;
}

Expand Down Expand Up @@ -1098,59 +1090,66 @@ static int get_user_input(std::string & user_input, const std::string & user) {

// Reads a chat template file to be used
static std::string read_chat_template_file(const std::string & chat_template_file) {
if(chat_template_file.empty()){
return "";
}

File file;
std::string chat_template = "";
chat_template = file.read_all(chat_template_file);
if(chat_template.empty()){
if (!file.open(chat_template_file, "r")) {
printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno));
return "";
}
return chat_template;

return file.to_string();
}

static int process_user_message(const Opt & opt, const std::string & user_input, LlamaData & llama_data,
const common_chat_templates_ptr & chat_templates, int & prev_len,
const bool stdout_a_terminal) {
add_message("user", opt.user.empty() ? user_input : opt.user, llama_data);
int new_len;
if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, opt.use_jinja) < 0) {
return 1;
}

std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
std::string response;
if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
return 1;
}

if (!opt.user.empty()) {
return 2;
}

add_message("assistant", response, llama_data);
if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, opt.use_jinja) < 0) {
return 1;
}

return 0;
}

// Main chat loop function
static int chat_loop(LlamaData & llama_data, const std::string & user, const std::string & chat_template_file, bool use_jinja) {
static int chat_loop(LlamaData & llama_data, const Opt & opt) {
int prev_len = 0;
llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));

std::string chat_template = "";
if(!chat_template_file.empty()){
chat_template = read_chat_template_file(chat_template_file);
std::string chat_template;
if (!opt.chat_template_file.empty()) {
chat_template = read_chat_template_file(opt.chat_template_file);
}
auto chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template.empty() ? nullptr : chat_template);

common_chat_templates_ptr chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template);
static const bool stdout_a_terminal = is_stdout_a_terminal();
while (true) {
// Get user input
std::string user_input;
if (get_user_input(user_input, user) == 1) {
if (get_user_input(user_input, opt.user) == 1) {
return 0;
}

add_message("user", user.empty() ? user_input : user, llama_data);
int new_len;
if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, use_jinja) < 0) {
return 1;
}

std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
std::string response;
if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
const int ret = process_user_message(opt, user_input, llama_data, chat_templates, prev_len, stdout_a_terminal);
if (ret == 1) {
return 1;
}

if (!user.empty()) {
} else if (ret == 2) {
break;
}

add_message("assistant", response, llama_data);
if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, use_jinja) < 0) {
return 1;
}
}

return 0;
Expand Down Expand Up @@ -1208,7 +1207,7 @@ int main(int argc, const char ** argv) {
return 1;
}

if (chat_loop(llama_data, opt.user, opt.chat_template_file, opt.use_jinja)) {
if (chat_loop(llama_data, opt)) {
return 1;
}

Expand Down

0 comments on commit 3caa933

Please sign in to comment.