Skip to content

Commit

Permalink
Add basic results evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
PedroJSilva2001 committed May 26, 2024
1 parent 3dc94d7 commit b5cfa97
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 15 deletions.
19 changes: 12 additions & 7 deletions experiments/declaration.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,21 @@

ARRAY_SIZE = 10000

class Error:
    """A single diagnostic collected while lexing or parsing declarations.

    Attributes:
        kind: Origin of the diagnostic; the error listeners in this file
            pass ``"lexer"`` or ``"parser"``.
        message: Human-readable description of the problem.
    """

    def __init__(self, kind, message):
        self.kind = kind
        self.message = message

    def __str__(self):
        # Makes ``print(error)`` show the diagnostic text instead of the
        # default ``<... object at 0x...>`` representation.
        return str(self.message)

    def __repr__(self):
        return f"{type(self).__name__}(kind={self.kind!r}, message={self.message!r})"

class DeclLexerErrorListener(ErrorListener):
def __init__(self):
super().__init__()
self.errors = []

def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
raise Exception(f"Lexing error at line {line}, column {column}: {msg}")
self.errors.append(Error("lexer", f"Lexing error at line {line}, column {column}: {msg}"))

def reportAmbiguity(self, recognizer, dfa, startIndex, stopIndex, exact, ambigAlts, configs):
raise Exception("Ambiguity")
raise Exception("Ambiguity error")

def has_errors(self):
return len(self.errors) > 0
Expand All @@ -44,10 +49,10 @@ def __init__(self):
self.errors = []

def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
raise Exception(f"Syntax error at line {line}, column {column}: {msg}")
self.errors.append(Error("parser", f"Syntax error at line {line}, column {column}: {msg}"))

def reportAmbiguity(self, recognizer, dfa, startIndex, stopIndex, exact, ambigAlts, configs):
raise Exception("Ambiguity")
raise Exception("Ambiguity error")

def has_errors(self):
return len(self.errors) > 0
Expand Down Expand Up @@ -289,12 +294,12 @@ def _get_type_specifier_as_c_type(type_specifier):
if lexer_error_listener.has_errors():
for error in lexer_error_listener.get_errors():
print(error)
print(error.message)
exit(1)
if parser_error_listener.has_errors():
for error in parser_error_listener.get_errors():
print(error)
print(error.message)
exit(1)
decls = DeclParser.get_declarations_as_obj(decls_tree)
Expand Down
60 changes: 60 additions & 0 deletions experiments/evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from declaration import DeclParser, DeclConverter

from utils import compile_with_clang

def evaluate_experiment_result(llm_answer, code_snippet):
    """Check whether the declarations in an LLM answer make a snippet compile.

    The answer is parsed with ``DeclParser.parse``. If the lexer or parser
    listener collected any error, the error messages are printed and the
    function returns early. Otherwise the parsed declarations are converted
    to C declarations, placed in front of ``code_snippet``, written to
    ``temp.c`` in the current working directory and passed to
    ``compile_with_clang``. The patched code and the compilation outcome are
    printed.

    Args:
        llm_answer: Text of the LLM answer holding the declarations.
        code_snippet: C source text the generated declarations are
            prepended to.

    Returns:
        None. The outcome is only reported on standard output.
    """
    decls_tree, lexer_error_listener, parser_error_listener = DeclParser.parse(llm_answer)

    errors = [
        *lexer_error_listener.get_errors(),
        *parser_error_listener.get_errors(),
    ]
    if errors:
        print("Errors found")
        for error in errors:
            print(error.message)
        return

    decls = DeclParser.get_declarations_as_obj(decls_tree)
    c_decls = DeclConverter.get_declarations_as_c_decls(decls)

    # TODO do we need to add the main? currently we are adding it because the linker is complaining
    # The snippet is joined with the same newline separator as the
    # declarations: a snippet that starts with a preprocessor directive
    # (e.g. ``#include``) must begin on its own line, not directly after the
    # last declaration.
    patched_code = "\n".join([*c_decls, code_snippet])

    temp_file_path = "temp.c"
    # The ``with`` block closes the file; no explicit close() is needed.
    with open(temp_file_path, "w") as f:
        f.write(patched_code)

    code_compiled, err = compile_with_clang(temp_file_path)

    print(patched_code)
    if code_compiled:
        print("Code compiled successfully")
    else:
        print("Compilation failed")
        print(err)


52 changes: 45 additions & 7 deletions experiments/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from error import extract_errors, parse_errors, get_errors_as_str
from utils import has_extension, get_last_path_component, get_files_from_path, remove_first_path_component
from declaration import DeclParser
from evaluation import evaluate_experiment_result

CODE_TAG = "<CODE>"
ERRORS_TAG = "<ERRORS>"
Expand Down Expand Up @@ -309,8 +310,47 @@ def run_experiments(paths: Paths):
if continue_exec.lower() in ["n", "no"]:
break

def analyse_experiments(paths: Paths):
print("Analysing experiments..")
def evaluate_results(project_id, paths: Paths):
    """Evaluate every stored raw result of one project.

    Each JSON file found in ``paths.raw_results_path`` is loaded, the code
    snippet of the experiment it refers to is read from the project's
    snippets directory, and both are handed to
    ``evaluate_experiment_result``.

    Args:
        project_id: Key of the project in the metadata returned by
            ``get_projects``.
        paths: Object providing ``raw_results_path``, ``metadata_path`` and
            ``snippets_path``.

    Raises:
        OSError, ValueError: Re-raised when a result file cannot be read or
            does not contain valid JSON (after printing which file failed).
        KeyError: If a result or the project metadata lacks an expected key.
    """
    print("Evaluating experiment results..")

    raw_results_filenames = get_files_from_path(paths.raw_results_path)

    projects = get_projects(paths.metadata_path)
    project = projects[project_id]

    for raw_result_filename in raw_results_filenames:
        print(raw_result_filename)
        if not has_extension(raw_result_filename, JSON_EXTENSION):
            continue

        raw_result_path = os.path.join(paths.raw_results_path, raw_result_filename)

        # Only the load itself is guarded, so that unrelated failures further
        # down (missing keys, evaluation errors) are not misreported as a
        # problem with the result file.
        try:
            with open(raw_result_path, 'r') as raw_result_file:
                raw_result = json.load(raw_result_file)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and decoding errors.
            print(f"Couldn't load result file '{raw_result_filename}'")
            raise

        experiment_id = raw_result["experiment"]
        experiment_snippet = project["experiments"][experiment_id]["snippet"]

        snippet_path = os.path.join(
            paths.snippets_path,
            project["metadata"]["snippets"],
            experiment_snippet,
        )
        with open(snippet_path) as file:
            code_snippet = file.read()

        llm_answer = raw_result["llm_answer"]

        evaluate_experiment_result(llm_answer, code_snippet)

# LLM
def build_prompt(prompt_template, snippet, errors):
Expand Down Expand Up @@ -498,9 +538,7 @@ def test_all_projects(projects, prompt_template, prompt_template_filename, snipp

save_experiment_result(chat_completion, gen_time, raw_results_path, experiment_name, project_name, project_org, prompt_template_filename)
#print(prompt)

def analyse_results():
pass


def main():
experiments_path = select_experiments_path()
Expand Down Expand Up @@ -534,7 +572,7 @@ def main():
[0] Create new project metadata
[1] Create new code snippet metadata
[2] Run experiments
[3] Analyse experiments
[3] Analyse experiment results
[4] Exit
"""
)
Expand All @@ -554,7 +592,7 @@ def main():
elif option == "2":
run_experiments(paths)
elif option == "3":
analyse_experiments(paths)
evaluate_results("doom", paths)
elif option == "4":
break

Expand Down
10 changes: 9 additions & 1 deletion experiments/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import subprocess
import os
import logging

logger = logging.getLogger(__name__)

def has_extension(file_path, target_extension):
_, file_extension = os.path.splitext(file_path)
Expand All @@ -24,8 +27,13 @@ def get_last_path_component(path):
def get_files_from_path(directory_path):
    """Return the names of the files located directly in ``directory_path``.

    Only entries for which ``os.path.isfile`` is true are kept, so
    subdirectories are left out. The names are returned without the
    directory prefix, in the order ``os.listdir`` yields them.
    """
    file_names = []
    for entry in os.listdir(directory_path):
        entry_path = os.path.join(directory_path, entry)
        if os.path.isfile(entry_path):
            file_names.append(entry)
    return file_names

# Used to:
# 1. retrieve compiler errors
# 2. compile the patched code
def compile_with_clang(source_file_path, output_file="a.out"):
compile_command = ["clang-15", source_file_path, "-o", output_file, "-ferror-limit=0"]
compile_command = ["clang", source_file_path, "-ferror-limit=0", "-S"]

# TODO assert at least version 15 (for errors on implicit function declarations)

try:
result = subprocess.run(compile_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
Expand Down

0 comments on commit b5cfa97

Please sign in to comment.