From 43792dadd4ad143d79ea7f9c72cebd601bfb023c Mon Sep 17 00:00:00 2001 From: Arthur Chan Date: Thu, 2 Jan 2025 20:34:47 +0000 Subject: [PATCH] [Rust] Initial implementation for Rust Signed-off-by: Arthur Chan --- data_prep/introspector.py | 23 +++++-- data_prep/project_src.py | 8 +++ experiment/benchmark.py | 7 ++ experiment/builder_runner.py | 15 +++-- llm_toolkit/prompt_builder.py | 93 +++++++++++++++++++++++++++ prompts/template_xml/rust_base.txt | 3 + prompts/template_xml/rust_problem.txt | 33 ++++++++++ run_one_experiment.py | 6 ++ 8 files changed, 177 insertions(+), 11 deletions(-) create mode 100644 prompts/template_xml/rust_base.txt create mode 100644 prompts/template_xml/rust_problem.txt diff --git a/data_prep/introspector.py b/data_prep/introspector.py index 6d7a04285..835a6abb1 100755 --- a/data_prep/introspector.py +++ b/data_prep/introspector.py @@ -57,7 +57,7 @@ INTROSPECTOR_ORACLE_FAR_REACH = '' INTROSPECTOR_ORACLE_KEYWORD = '' INTROSPECTOR_ORACLE_EASY_PARAMS = '' -INTROSPECTOR_ORACLE_ALL_JVM_PUBLIC_CANDIDATES = '' +INTROSPECTOR_ORACLE_ALL_PUBLIC_CANDIDATES = '' INTROSPECTOR_ORACLE_OPTIMAL = '' INTROSPECTOR_ORACLE_ALL_TESTS = '' INTROSPECTOR_FUNCTION_SOURCE = '' @@ -90,6 +90,7 @@ def get_oracle_dict() -> Dict[str, Any]: 'jvm-public-candidates': query_introspector_jvm_all_public_candidates, 'optimal-targets': query_introspector_for_optimal_targets, 'test-migration': query_introspector_for_tests, + 'all-public-candidates': query_introspector_all_public_candidates, } return oracle_dict @@ -102,7 +103,7 @@ def set_introspector_endpoints(endpoint): INTROSPECTOR_ORACLE_KEYWORD, INTROSPECTOR_ADDR_TYPE, \ INTROSPECTOR_ALL_HEADER_FILES, INTROSPECTOR_ALL_FUNC_TYPES, \ INTROSPECTOR_SAMPLE_XREFS, INTROSPECTOR_ORACLE_EASY_PARAMS, \ - INTROSPECTOR_ORACLE_ALL_JVM_PUBLIC_CANDIDATES, \ + INTROSPECTOR_ORACLE_ALL_PUBLIC_CANDIDATES, \ INTROSPECTOR_ALL_JVM_SOURCE_PATH, INTROSPECTOR_ORACLE_OPTIMAL, \ INTROSPECTOR_HEADERS_FOR_FUNC, \ INTROSPECTOR_FUNCTION_WITH_MATCHING_RETURN_TYPE, \ @@ -119,7 +120,7 @@ def set_introspector_endpoints(endpoint): f'{INTROSPECTOR_ENDPOINT}/far-reach-low-cov-fuzz-keyword') INTROSPECTOR_ORACLE_EASY_PARAMS = ( f'{INTROSPECTOR_ENDPOINT}/easy-params-far-reach') - INTROSPECTOR_ORACLE_ALL_JVM_PUBLIC_CANDIDATES = ( + INTROSPECTOR_ORACLE_ALL_PUBLIC_CANDIDATES = ( f'{INTROSPECTOR_ENDPOINT}/all-public-candidates') INTROSPECTOR_ORACLE_OPTIMAL = f'{INTROSPECTOR_ENDPOINT}/optimal-targets' INTROSPECTOR_FUNCTION_SOURCE = f'{INTROSPECTOR_ENDPOINT}/function-source-code' @@ -278,7 +279,16 @@ def query_introspector_jvm_all_public_candidates(project: str) -> list[dict]: constructor candidates. """ return query_introspector_oracle( - project, INTROSPECTOR_ORACLE_ALL_JVM_PUBLIC_CANDIDATES) + project, INTROSPECTOR_ORACLE_ALL_PUBLIC_CANDIDATES) + + +def query_introspector_all_public_candidates(project: str) -> list[dict]: + """Queries Fuzz Introspector for all public accessible function or + constructor candidates. + """ + #TODO May combine this with query_introspector_jvm_all_public_candidates + return query_introspector_oracle( + project, INTROSPECTOR_ORACLE_ALL_PUBLIC_CANDIDATES) def query_introspector_for_targets(project, target_oracle) -> list[Dict]: @@ -859,7 +869,7 @@ def populate_benchmarks_using_introspector(project: str, language: str, # arguments. Thus skipping it. continue - if language == 'jvm': + elif language == 'jvm': # Retrieve list of source file from introspector src_path_list = query_introspector_jvm_source_path(project) if src_path_list: @@ -872,7 +882,8 @@ def populate_benchmarks_using_introspector(project: str, language: str, if src_file not in src_path_list: logger.error('error: %s %s', filename, interesting.keys()) continue - elif language != 'python' and interesting and filename not in [ + + elif language != 'rust' and interesting and filename not in [ os.path.basename(i) for i in interesting.keys() ]: # TODO: Bazel messes up paths to include "/proc/self/cwd/..." diff --git a/data_prep/project_src.py b/data_prep/project_src.py index dbc63e6a5..476c796ee 100755 --- a/data_prep/project_src.py +++ b/data_prep/project_src.py @@ -96,6 +96,8 @@ def _get_harness(src_file: str, out: str, language: str) -> tuple[str, str]: return '', '' if language.lower() == 'python' and 'atheris.Fuzz()' not in content: return '', '' + if language.lower() == 'rust' and 'fuzz_target!' not in content: + return '', '' short_path = src_file[len(out):] return short_path, content @@ -307,6 +309,12 @@ def _identify_fuzz_targets(out: str, interesting_filenames: list[str], interesting_filepaths.append(path) if path.endswith('.py'): potential_harnesses.append(path) + elif language == 'rust': + # For Rust + if path.endswith(tuple(interesting_filenames)): + interesting_filepaths.append(path) + if path.endswith('.rs'): + potential_harnesses.append(path) else: # For C/C++ short_path = path[len(out):] diff --git a/experiment/benchmark.py b/experiment/benchmark.py index f3d4c1071..aac924349 100644 --- a/experiment/benchmark.py +++ b/experiment/benchmark.py @@ -204,6 +204,13 @@ def __init__(self, # zipp-zipp.difference. self.id = self.id.replace('._', '.') + if self.language = 'rust': + # For rust projects, double colon (::) is sometime used to identify + # crate, impl or trait name of a function. This could affect the + # benchmark_id and cause OSS-Fuzz build failed. + # Special handling of benchmark_id is needed to avoid this situation. + self.id = self.id.replace('::', '-') + def __str__(self): return (f'Benchmark str: 'jvm': 'jacoco.xml', 'python': 'all_cov.json', 'c++': f'{self.benchmark.target_name}.covreport', - 'c': f'{self.benchmark.target_name}.covreport' + 'c': f'{self.benchmark.target_name}.covreport', + 'rust': f'{self.benchmark.target_name}.covreport', } return os.path.join(get_build_artifact_dir(project_name, @@ -696,6 +699,7 @@ def _extract_local_textcoverage_data(self, 'python': 'r', 'c': 'rb', 'c++': 'rb', + 'rust': 'rb', } with open(local_textcov_location, language_modes.get(self.benchmark.language, 'rb')) as f: @@ -1033,7 +1037,7 @@ def build_and_run_cloud( self._copy_textcov_to_workdir(bucket, textcov_blob_path, generated_target_name) else: - # C/C++ + # C/C++/Rust blob = bucket.blob(textcov_blob_path) if blob.exists(): with blob.open('rb') as f: @@ -1075,6 +1079,7 @@ def _get_cloud_textcov_path(self, coverage_name: str) -> str: if self.benchmark.language == 'python': return f'{coverage_name}/textcov_reports/all_cov.json' + # For C/C++/Rust return (f'{coverage_name}/textcov_reports/{self.benchmark.target_name}' '.covreport') diff --git a/llm_toolkit/prompt_builder.py b/llm_toolkit/prompt_builder.py index a42aebb99..f63a5ecab 100644 --- a/llm_toolkit/prompt_builder.py +++ b/llm_toolkit/prompt_builder.py @@ -1054,6 +1054,99 @@ def post_process_generated_code(self, generated_code: str) -> str: return generated_code +class DefaultRustTemplateBuilder(PromptBuilder): + """Default builder for Rust projects.""" + + def __init__(self, + model: models.LLM, + benchmark: Benchmark, + template_dir: str = DEFAULT_TEMPLATE_DIR): + super().__init__(model) + self._template_dir = template_dir + self.benchmark = benchmark + self.project_url = oss_fuzz_checkout.get_project_repository( + self.benchmark.project) + + # Load templates. + self.base_template_file = self._find_template(template_dir, + 'rust_base.txt') + self.problem_template_file = self._find_template(template_dir, + 'rust_problem.txt') + + def _find_template(self, template_dir: str, template_name: str) -> str: + """Finds template file based on |template_dir|.""" + preferred_template = os.path.join(template_dir, template_name) + # Use the preferred template if it exists. + if os.path.isfile(preferred_template): + return preferred_template + + # Fall back to the default template. + default_template = os.path.join(DEFAULT_TEMPLATE_DIR, template_name) + return default_template + + def _get_template(self, template_file: str) -> str: + """Reads the template for prompts.""" + with open(template_file) as file: + return file.read() + + def _format_target(self, signature: str) -> str: + """Format the target function for the prompts creation.""" + target = self._get_template(self.problem_template_file) + arg_count = len(self.benchmark.params) + arg_type = [arg_dict['type'] for arg_dict in self.benchmark.params] + + target = target.replace('{FUNCTION_SIGNATURE}', signature) + target = target.replace('{ARG_COUNT}', str(arg_count)) + target = target.replace('{ARG_TYPE}', ','.join(arg_type)) + + return target + + def _format_problem(self, signature: str) -> str: + """Formats a problem based on the prompt template.""" + base = self._get_template(self.base_template_file) + target_str = self._format_target(signature) + + problem = base + target_str + problem = problem.replace("{PROJECT_NAME}", self.benchmark.project) + problem = problem.replace("{PROJECT_URL}", self.project_url) + + return problem + + def _prepare_prompt(self, prompt_str: str): + """Constructs a prompt using the parameters and saves it.""" + self._prompt.add_priming(prompt_str) + + def build(self, + example_pair: list[list[str]], + project_example_content: Optional[list[list[str]]] = None, + project_context_content: Optional[dict] = None) -> prompts.Prompt: + """Constructs a prompt using the templates in |self| and saves it. + Ignore target_file_type, project_example_content + and project_context_content parameters. + """ + final_problem = self._format_problem(self.benchmark.function_signature) + self._prepare_prompt(final_problem) + return self._prompt + + def build_fixer_prompt(self, benchmark: Benchmark, raw_code: str, + error_desc: Optional[str], + errors: list[str]) -> prompts.Prompt: + """Builds a fixer prompt.""" + # Do nothing for rust project now. + return self._prompt + + def build_triager_prompt(self, benchmark: Benchmark, driver_code: str, + crash_info: str, crash_func: dict) -> prompts.Prompt: + """Builds a triager prompt.""" + # Do nothing for rust project now. + return self._prompt + + def post_process_generated_code(self, generated_code: str) -> str: + """Allows prompt builder to adjust the generated code.""" + # Do nothing for rust project now. + return generated_code + + class JvmErrorFixingBuilder(PromptBuilder): """Prompt builder for fixing JVM harness with complication error.""" diff --git a/prompts/template_xml/rust_base.txt b/prompts/template_xml/rust_base.txt new file mode 100644 index 000000000..835faea50 --- /dev/null +++ b/prompts/template_xml/rust_base.txt @@ -0,0 +1,3 @@ +You are a security testing engineer who wants to write a Rust program to execute all lines in a given method by defining and initialising its parameters and necessary objects in a suitable way before fuzzing the method. +The tag contains information of the target method to invoke. +The tag contains additional requirements that you MUST follow for this code generation. diff --git a/prompts/template_xml/rust_problem.txt b/prompts/template_xml/rust_problem.txt new file mode 100644 index 000000000..53dd102ae --- /dev/null +++ b/prompts/template_xml/rust_problem.txt @@ -0,0 +1,33 @@ + +Your goal is to write a fuzzing harness for the provided method signature to fuzz the method with random data. It is important that the provided solution compiles and actually calls the function specified by the method signature: + + +{FUNCTION_SIGNATURE} + +The target function is belonging to the Rust project {PROJECT_NAME} ({PROJECT_URL}). +This function requires {ARG_COUNT} arguments. You must prepare them with random seeded data. +Here is a list of types for all arguments in order, separated by comma. You MUST preserve the modifiers. +{ARG_TYPE} + + +Try as many variations of these inputs as possible. +Try creating the harness as complex as possible. +Try adding some nested loop to invoke the target method for multiple times. +The generated fuzzing harness should be wrapped with the tag. +Please avoid using any multithreading or multi-processing approach. +You MUST create the fuzzing harness using Cargo-Fuzz approach. +you MUST use the #![no_main] tag. +You MUST use the libfuzzer_sys::fuzz_target crate. +You MUST include the fuzz_target macro to include all fuzzing statements. +The following is a sample of the fuzzing harness. + +#![no_main] +use libfuzzer_sys::fuzz_target; + +fuzz_target!(|data: &[u8]| { + // This is the macro acts as the entry point for the fuzzing harness. + // Add fuzzing logic here. +}); + + + diff --git a/run_one_experiment.py b/run_one_experiment.py index c29bb018a..82415a0a2 100644 --- a/run_one_experiment.py +++ b/run_one_experiment.py @@ -304,6 +304,12 @@ def generate_targets_for_analysis( # For Python projects builder = prompt_builder.DefaultPythonTemplateBuilder( model, benchmark, template_dir) + + elif benchmark.language == 'rust': + # For Rust projects + builder = prompt_builder.DefaultRustTemplateBuilder( + model, benchmark, template_dir) + elif prompt_builder_to_use == 'CSpecific': builder = prompt_builder.CSpecificBuilder(model, benchmark, template_dir) else: