Skip to content

Commit

Permalink
Add basic argument parser for testing the obfuscator
Browse files Browse the repository at this point in the history
  • Loading branch information
mrexodia committed Oct 11, 2024
1 parent ff74c2c commit d1a2f37
Show file tree
Hide file tree
Showing 2 changed files with 315 additions and 8 deletions.
278 changes: 278 additions & 0 deletions obfuscator/include/obfuscator/args.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,278 @@
#pragma once

#include <string>
#include <vector>
#include <functional>
#include <unordered_set>

class ArgumentParser
{
protected:
void addPositional(const std::string& name, std::string& value, const std::string& help, bool required = false)
{
auto fn = [this, &value]()
{
value = arg;
};
positionalArgs.push_back(Arg{name, help, required, fn});
}

void addString(const std::string& flagname, std::string& value, const std::string& help, bool required = false)
{
auto fn = [this, flagname, &value]
{
if (arg.substr(0, flagname.length()) == flagname)
{
if (arg.length() == flagname.length())
{
// -flagname <value>
if (i + 1 >= argc)
{
throw std::runtime_error("missing value for '" + flagname + "' argument");
}
value = argv[++i];
if (value.empty())
{
throw std::runtime_error("empty value for '" + flagname + "' argument");
}
markExtracted(flagname);
}
else if (arg[flagname.length()] == '=')
{
// -flagname=<value>
value = arg.substr(flagname.length() + 1);
markExtracted(flagname);
}
}
};
flagArgs.push_back(Arg{flagname, help, required, fn});
}

void addBool(const std::string& flagname, bool& value, const std::string& help, bool required = false)
{
auto fn = [this, flagname, &value]
{
if (arg.substr(0, flagname.length()) == flagname)
{
if (arg.length() == flagname.length())
{
// -flagname
value = true;
markExtracted(flagname);
}
else if (arg[flagname.length()] == '=')
{
// -flagname=<value>
auto strValue = arg.substr(flagname.length() + 1);
if (strValue.empty())
{
throw std::runtime_error("empty value for '" + flagname + "' argument");
}
value = strValue == "1" || strValue == "true";
markExtracted(flagname);
}
}
};
flagArgs.push_back(Arg{flagname, help, required, fn});
}

public:
explicit ArgumentParser(std::string description) : description(std::move(description))
{
}

virtual ~ArgumentParser() = default;
ArgumentParser(const ArgumentParser&) = delete;
ArgumentParser& operator=(const ArgumentParser&) = delete;
ArgumentParser(ArgumentParser&&) = delete;
ArgumentParser& operator=(ArgumentParser&&) = delete;

void parse(int argc, char** argv)
{
this->argc = argc;
this->argv = argv;
bool seenRequired = false;
for (const auto& positionalArg : positionalArgs)
{
if (positionalArg.name.empty())
{
throw std::runtime_error("cannot add positional argument without name");
}
if (!positionalArg.required)
{
if (seenRequired)
{
throw std::runtime_error("cannot add required positional argument after an optional one");
}
}
else
{
seenRequired = true;
}
}
for (const auto& flagArg : flagArgs)
{
if (flagArg.name.empty())
{
throw std::runtime_error("cannot add argument without name");
}
if (flagArg.name[0] != '-')
{
throw std::runtime_error("invalid argument name '" + flagArg.name + "'");
}
}
size_t positionalIndex = 0;
for (i = 1; i < argc; i++)
{
arg = std::string(argv[i]);
if (arg.empty())
{
continue;
}
if (arg[0] == '-')
{
didExtract = false;
for (const auto& flag : flagArgs)
{
flag.fn();
}
if (!didExtract)
{
throw std::runtime_error("unknown argument '" + arg + "'");
}
}
else
{
if (positionalIndex + 1 > positionalArgs.size())
{
throw std::runtime_error("unexpected positional argument '" + arg + "'");
}
const auto& positionalArg = positionalArgs[positionalIndex++];
if (positionalArg.name[0] == '-')
{
markExtracted(positionalArg.name);
}
positionalArg.fn();
}
}
for (const auto& flagArg : flagArgs)
{
if (!flagArg.required)
{
continue;
}
if (!flagsExtracted.contains(flagArg.name))
{
throw std::runtime_error("required argument '" + flagArg.name + "' missing");
}
}
for (size_t i = positionalIndex; i < positionalArgs.size(); i++)
{
const auto& positionalArg = positionalArgs[i];
if (positionalArg.required)
{
if (flagsExtracted.contains(positionalArg.name))
{
continue;
}
throw std::runtime_error("required positional argument missing");
}
}
}

[[nodiscard]] std::string helpStr() const
{
std::string help;
help += " ";
help += argv[0];
help += " {OPTIONS}";

for (const auto& positionalArg : positionalArgs)
{
help += " ";
if (!positionalArg.required)
{
help += '[';
}
if (positionalArg.name[0] == '-')
{
help += "[" + positionalArg.name + "]";
help += " <value>";
}
else
{
help += positionalArg.name;
}
if (!positionalArg.required)
{
help += ']';
}
}
help += '\n';

if (!description.empty())
{
help += "\n ";
help += description;
help += "\n\n";
}

help += " OPTIONS:\n";

size_t maxLen = 0;
for (const auto& flagArg : flagArgs)
{
if (flagArg.name.size() > maxLen)
{
maxLen = flagArg.name.size();
}
}
for (const auto& flagArg : flagArgs)
{
help += "\n ";
help += flagArg.name;
for (size_t i = 0; i < maxLen - flagArg.name.size(); i++)
{
help += ' ';
}
help += " ";
help += flagArg.help;
if (!flagArg.required)
{
help += " (optional)";
}
}

return help;
}

private:
struct Arg
{
std::string name;
std::string help;
bool required = false;
std::function<void()> fn;
};

std::string description;
std::vector<Arg> positionalArgs;
std::vector<Arg> flagArgs;

int i = 1;
int argc = 0;
char** argv = nullptr;
bool didExtract = false;
std::string arg;
std::unordered_set<std::string> flagsExtracted;

void markExtracted(const std::string& flagname)
{
didExtract = true;
if (flagsExtracted.contains(flagname))
{
throw std::runtime_error("duplicate value for '" + flagname + "' argument");
}
flagsExtracted.insert(flagname);
}
};
45 changes: 37 additions & 8 deletions obfuscator/src/obfuscate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,16 +142,44 @@ static bool riscvm_handle_syscall(vm::riscvm* self, uint64_t code, uint64_t* res

#endif // _WIN32

int main(int argc, char** argv)
#include <obfuscator/args.hpp>

struct Arguments : ArgumentParser
{
if (argc < 2)
std::string input;
std::string output;
std::string payload;
bool help;

Arguments(int argc, char** argv) : ArgumentParser("Obfuscates the riscvm_run function")
{
puts("Usage: obfuscator riscvm.exe [payload.bin]");
return EXIT_FAILURE;
addPositional("input", input, "Input PE file to obfuscate", true);
addString("-output", output, "Obfuscated function output");
addString("-payload", payload, "Payload to execute (Windows only)");
addBool("-help", help, "Prints this help message");
try
{
parse(argc, argv);
}
catch (const std::exception& e)
{
printf("Error: %s\n\nHelp:\n%s\n", e.what(), helpStr().c_str());
std::exit(help ? EXIT_SUCCESS : EXIT_FAILURE);
}
if (help)
{
puts(helpStr().c_str());
std::exit(EXIT_SUCCESS);
}
}
};

int main(int argc, char** argv)
{
Arguments args(argc, argv);

std::vector<uint8_t> pe;
if (!loadFile(argv[1], pe))
if (!loadFile(args.input, pe))
{
puts("Failed to load the executable.");
return EXIT_FAILURE;
Expand Down Expand Up @@ -205,8 +233,9 @@ int main(int argc, char** argv)
auto size = serializer.getCodeSize();

// Save the obfuscated code to disk
if (!args.output.empty())
{
std::ofstream ofs("riscvm_run_obfuscated.bin", std::ios::binary);
std::ofstream ofs(args.output, std::ios::binary);
ofs.write((char*)ptr, size);
}

Expand All @@ -226,10 +255,10 @@ int main(int argc, char** argv)
__debugbreak();

// Run the payload if specified on the command line
if (argc > 2)
if (!args.payload.empty())
{
std::vector<uint8_t> payload;
if (!loadFile(argv[2], payload))
if (!loadFile(args.payload, payload))
{
puts("Failed to load the payload.");
return EXIT_FAILURE;
Expand Down

0 comments on commit d1a2f37

Please sign in to comment.